unify is_cuda and is_hip (#4321)

This commit is contained in:
Yineng Zhang
2025-03-11 18:12:56 -07:00
committed by GitHub
parent 1cf63485c1
commit d1da58e275
18 changed files with 104 additions and 92 deletions

View File

@@ -6,8 +6,9 @@ import triton
import triton.language as tl
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import is_cuda
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,

View File

@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
logger = logging.getLogger(__name__)
_is_hip = is_hip()
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

View File

@@ -23,10 +23,11 @@ from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_device_name,
is_cuda,
is_hip,
)
is_hip_ = is_hip()
_is_hip = is_hip()
logger = logging.getLogger(__name__)
@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
@@ -46,7 +46,7 @@ if _is_cuda:
sglang_per_token_group_quant_fp8,
)
if _is_cuda or _is_rocm:
if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -679,7 +679,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2 if is_hip_ else 4,
"num_stages": 2 if _is_hip else 4,
}
if M <= E:
config = {
@@ -688,7 +688,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2 if is_hip_ else 4,
"num_stages": 2 if _is_hip else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -698,7 +698,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2 if is_hip_ else 3,
"num_stages": 2 if _is_hip else 3,
}
else:
config = {
@@ -976,7 +976,7 @@ def fused_experts_impl(
if (
not (use_fp8_w8a8 or use_int8_w8a8)
or block_shape is not None
or (is_hip_ and get_bool_env_var("CK_MOE"))
or (_is_hip and get_bool_env_var("CK_MOE"))
):
padded_size = 0
@@ -1131,7 +1131,7 @@ def fused_experts_impl(
if no_combine:
pass
elif is_hip_:
elif _is_hip:
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],

View File

@@ -27,9 +27,9 @@ else:
import logging
is_hip_ = is_hip()
_is_hip = is_hip()
if is_hip_:
if _is_hip:
from aiter import ck_moe
logger = logging.getLogger(__name__)
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if is_hip_ and get_bool_env_var("CK_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
correction_bias=correction_bias,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, "unsupported"
return ck_moe(
x,
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
# this is needed for compressed-tensors only
@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5
self._load_per_channel_weight_scale(
@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
self._load_per_tensor_weight_scale(