unify is_cuda and is_hip (#4321)
This commit is contained in:
@@ -6,8 +6,9 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
|
||||
@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class GroupedGemmRunner(torch.nn.Module):
|
||||
flashinfer_gemm_warpper = None
|
||||
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If rocm, use float8_e4m3fnuz as dtype
|
||||
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
|
||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
|
||||
@@ -23,10 +23,11 @@ from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
get_device_name,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
)
|
||||
|
||||
is_hip_ = is_hip()
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
|
||||
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||
)
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
_is_rocm = torch.cuda.is_available() and torch.version.hip
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
@@ -46,7 +46,7 @@ if _is_cuda:
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
if _is_cuda or _is_rocm:
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
|
||||
@@ -679,7 +679,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2 if is_hip_ else 4,
|
||||
"num_stages": 2 if _is_hip else 4,
|
||||
}
|
||||
if M <= E:
|
||||
config = {
|
||||
@@ -688,7 +688,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2 if is_hip_ else 4,
|
||||
"num_stages": 2 if _is_hip else 4,
|
||||
}
|
||||
else:
|
||||
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
@@ -698,7 +698,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2 if is_hip_ else 3,
|
||||
"num_stages": 2 if _is_hip else 3,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
@@ -976,7 +976,7 @@ def fused_experts_impl(
|
||||
if (
|
||||
not (use_fp8_w8a8 or use_int8_w8a8)
|
||||
or block_shape is not None
|
||||
or (is_hip_ and get_bool_env_var("CK_MOE"))
|
||||
or (_is_hip and get_bool_env_var("CK_MOE"))
|
||||
):
|
||||
padded_size = 0
|
||||
|
||||
@@ -1131,7 +1131,7 @@ def fused_experts_impl(
|
||||
|
||||
if no_combine:
|
||||
pass
|
||||
elif is_hip_:
|
||||
elif _is_hip:
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
|
||||
@@ -27,9 +27,9 @@ else:
|
||||
|
||||
import logging
|
||||
|
||||
is_hip_ = is_hip()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if is_hip_:
|
||||
if _is_hip:
|
||||
from aiter import ck_moe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
correction_bias=correction_bias,
|
||||
)
|
||||
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||
assert not no_combine, "unsupported"
|
||||
return ck_moe(
|
||||
x,
|
||||
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
|
||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
loaded_weight = loaded_weight * 2.0
|
||||
|
||||
# this is needed for compressed-tensors only
|
||||
@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
|
||||
quant_method = getattr(param, "quant_method", None)
|
||||
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
|
||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
loaded_weight = loaded_weight * 0.5
|
||||
|
||||
self._load_per_channel_weight_scale(
|
||||
@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
|
||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
loaded_weight = loaded_weight * 2.0
|
||||
|
||||
self._load_per_tensor_weight_scale(
|
||||
|
||||
Reference in New Issue
Block a user