From d1da58e275e377db3b131e9988b432a37aa232c0 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 11 Mar 2025 18:12:56 -0700 Subject: [PATCH] unify is_cuda and is_hip (#4321) --- python/sglang/srt/custom_op.py | 8 +++-- .../device_communicators/custom_all_reduce.py | 35 ++++++++++--------- .../attention/triton_ops/decode_attention.py | 12 +++---- .../triton_ops/double_sparsity_attention.py | 6 ++-- .../attention/triton_ops/extend_attention.py | 8 ++--- .../triton_ops/rocm_mla_decode_rope.py | 6 ++-- .../sglang/srt/layers/moe/ep_moe/kernels.py | 3 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 ++- .../layers/moe/fused_moe_triton/fused_moe.py | 18 +++++----- .../srt/layers/moe/fused_moe_triton/layer.py | 14 ++++---- python/sglang/srt/layers/quantization/fp8.py | 26 +++++++------- .../srt/layers/quantization/fp8_kernel.py | 14 ++++---- .../srt/layers/quantization/fp8_utils.py | 8 ++--- .../srt/layers/quantization/w8a8_fp8.py | 4 ++- .../srt/model_executor/cuda_graph_runner.py | 4 +-- python/sglang/srt/models/deepseek_nextn.py | 6 ++-- python/sglang/srt/models/deepseek_v2.py | 8 ++--- python/sglang/srt/utils.py | 12 ++++--- 18 files changed, 104 insertions(+), 92 deletions(-) diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index d770e9c08..2e066efe8 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -1,8 +1,10 @@ import torch from torch import nn -_is_cuda = torch.cuda.is_available() and torch.version.cuda -_is_rocm = torch.cuda.is_available() and torch.version.hip +from sglang.srt.utils import is_cuda, is_hip + +_is_cuda = is_cuda() +_is_hip = is_hip() class CustomOp(nn.Module): @@ -34,7 +36,7 @@ class CustomOp(nn.Module): def dispatch_forward(self): if _is_cuda: return self.forward_cuda - elif _is_rocm: + elif _is_hip: return self.forward_hip else: return self.forward_native diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index efd943b0e..2d5f9ada4 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -22,15 +22,16 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip logger = logging.getLogger(__name__) -is_hip_ = is_hip() +_is_cuda = is_cuda() +_is_hip = is_hip() -if is_cuda(): +if _is_cuda: try: import pynvml except ImportError as e: logger.warning("Failed to import pynvml with %r", e) -if is_hip_: +if _is_hip: try: from amdsmi import ( AmdSmiException, @@ -43,7 +44,7 @@ if is_hip_: logger.warning("Failed to import amdsmi with %r", e) try: - if ops.use_vllm_custom_allreduce and not is_hip_: + if ops.use_vllm_custom_allreduce and not _is_hip: # Use vLLM custom allreduce ops.meta_size() else: @@ -63,7 +64,7 @@ _R = TypeVar("_R") def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - if is_hip_: + if _is_hip: try: amdsmi_init() return fn(*args, **kwargs) @@ -81,7 +82,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @with_nvml_context def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool: - if is_hip_: + if _is_hip: """ query if the set of gpus are fully connected by xgmi (1 hop) """ @@ -145,7 +146,7 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _MAX_CAR_SIZE = 8192 * 1024 - if is_hip_: + if _is_hip: # 
crossover is at 16MB buffer size for ROCm _MAX_CAR_SIZE = 2 * 8192 * 1024 @@ -229,7 +230,7 @@ class CustomAllreduce: # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - if is_cuda() or is_hip_: + if _is_cuda or _is_hip: full_nvlink = is_full_nvlink(physical_device_ids, world_size) if world_size > 2 and not full_nvlink: @@ -243,7 +244,7 @@ class CustomAllreduce: # this is expensive to compute at the first time # then we cache the result # On AMD GPU, p2p is always enabled between XGMI connected GPUs - if not is_hip_ and not _can_p2p(rank, world_size): + if not _is_hip and not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " @@ -256,7 +257,7 @@ class CustomAllreduce: self.world_size = world_size self.full_nvlink = full_nvlink - if ops.use_vllm_custom_allreduce and not is_hip_: + if ops.use_vllm_custom_allreduce and not _is_hip: # Buffers memory are owned by this Python class and passed to C++. # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. @@ -279,7 +280,7 @@ class CustomAllreduce: ) ops.register_buffer(self._ptr, self.buffer_ptrs) else: - if is_hip_: + if _is_hip: # meta data buffers need to be "uncached" for signal on MI200 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size) self.buffer = torch.empty( @@ -418,7 +419,7 @@ class CustomAllreduce: ops.register_buffer(self._ptr, inp, handles, offsets) def register_graph_buffers(self): - if is_hip_: + if _is_hip: handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) @@ -454,12 +455,12 @@ class CustomAllreduce: return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. - if ops.use_vllm_custom_allreduce and not is_hip_: + if ops.use_vllm_custom_allreduce and not _is_hip: if self.world_size == 2 or self.full_nvlink: return inp_size < self.max_size return False - if is_hip_: + if _is_hip: if self.full_nvlink: if self.world_size == 8: if self.MSCCL: @@ -532,7 +533,7 @@ class CustomAllreduce: return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - if is_hip_: + if _is_hip: return self.all_reduce_reg(input) else: return self.all_reduce(input, registered=True) @@ -541,7 +542,7 @@ class CustomAllreduce: # allreduce is out-of-place. 
return torch.empty_like(input) else: - if is_hip_: + if _is_hip: # note: outside of cuda graph context, # custom allreduce incurs a cost of cudaMemcpy, which should # be small(<=1% of overall latency) compared to the performance @@ -556,7 +557,7 @@ class CustomAllreduce: if ops.use_vllm_custom_allreduce: self.free_shared_buffer(self.meta_ptrs) self.free_shared_buffer(self.buffer_ptrs) - elif is_cuda(): + elif _is_cuda: self.free_shared_buffer(self.buffer_ptrs) self.free_shared_buffer(self.tmp_result_buffer_ptrs) self.free_shared_buffer(self.barrier_in_ptrs) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 3b4853e40..a9ab44546 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -27,7 +27,7 @@ import triton.language as tl from sglang.srt.utils import is_hip -is_hip_ = is_hip() +_is_hip = is_hip() logger = logging.getLogger(__name__) @@ -180,7 +180,7 @@ def _decode_att_m_fwd( ): BLOCK = 64 # [TODO] work around SGPR limit on MI3xx - if is_hip_: + if _is_hip: BLOCK = 8 NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] @@ -195,7 +195,7 @@ def _decode_att_m_fwd( num_warps = 4 else: num_warps = 2 - if is_hip_: + if _is_hip: num_warps = 1 BLOCK_DMODEL = triton.next_power_of_2(Lk) @@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd( Lv = v_buffer.shape[-1] # [TODO] work around shmem limit on MI3xx - if is_hip_ and Lk >= 576: + if _is_hip and Lk >= 576: BLOCK = 16 if Lk == 576: @@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd( extra_kargs = {} num_stages = 2 - if is_hip_: + if _is_hip: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} @@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd( NUM_KV_SPLITS = num_kv_splits extra_kargs = {} - if is_hip_: + if _is_hip: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py index db0fb6b4d..459e43b48 100644 --- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py @@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available() if is_cuda_available: CUDA_CAPABILITY = torch.cuda.get_device_capability() -is_hip_ = is_hip() +_is_hip = is_hip() if global_server_args_dict.get("attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 @@ -1032,7 +1032,7 @@ def extend_attention_fwd( BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - if is_hip_: + if _is_hip: BLOCK_M, BLOCK_N = (64, 64) num_warps = 4 @@ -1062,7 +1062,7 @@ def extend_attention_fwd( num_stages = 1 extra_kargs = {} - if is_hip_: + if _is_hip: extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} _fwd_kernel[grid]( diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 079c8cfd9..f8f8f1088 
100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available() if is_cuda_available: CUDA_CAPABILITY = torch.cuda.get_device_capability() -is_hip_ = is_hip() +_is_hip = is_hip() @triton.jit @@ -330,7 +330,7 @@ def extend_attention_fwd( BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - if is_hip_: + if _is_hip: BLOCK_M, BLOCK_N = (64, 64) num_warps = 4 @@ -364,7 +364,7 @@ def extend_attention_fwd( num_stages = 1 extra_kargs = {} - if is_hip_: + if _is_hip: extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} _fwd_kernel[grid]( @@ -403,7 +403,7 @@ def extend_attention_fwd( Lv=Lv, USE_CUSTOM_MASK=USE_CUSTOM_MASK, SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK, - STORE_TRANSPOSE=is_hip_, + STORE_TRANSPOSE=_is_hip, num_warps=num_warps, num_stages=num_stages, **extra_kargs, diff --git a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py index 915f4ef92..89b17ce2d 100644 --- a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +++ b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py @@ -32,7 +32,7 @@ def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" -is_hip_ = is_hip() +_is_hip = is_hip() @triton.jit @@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope( BLOCK = 32 # # [TODO] work around shmem limit on MI3xx - # if is_hip_ and kv_lora_rank >= 576: + # if _is_hip and kv_lora_rank >= 576: # BLOCK = 16 qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank @@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope( extra_kargs = {} num_stages = 2 - if is_hip_: + if _is_hip: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 85f791889..b455c05c3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -6,8 +6,9 @@ import triton import triton.language as tl from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 +from sglang.srt.utils import is_cuda -_is_cuda = torch.cuda.is_available() and torch.version.cuda +_is_cuda = is_cuda() if _is_cuda: from sglang.srt.layers.quantization.fp8_kernel import ( sglang_per_token_group_quant_fp8, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 1e8e1c4d3..22eac0496 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs logger = logging.getLogger(__name__) +_is_hip = is_hip() + class GroupedGemmRunner(torch.nn.Module): flashinfer_gemm_warpper = None @@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod): # If checkpoint is fp16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 71c3d8ebe..40dbdc35d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -23,10 +23,11 @@ from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, get_device_name, + is_cuda, is_hip, ) -is_hip_ = is_hip() +_is_hip = is_hip() logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool( int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) ) -_is_cuda = torch.cuda.is_available() and torch.version.cuda -_is_rocm = torch.cuda.is_available() and torch.version.hip +_is_cuda = is_cuda() if _is_cuda: from sgl_kernel import gelu_and_mul, silu_and_mul @@ -46,7 +46,7 @@ if _is_cuda: sglang_per_token_group_quant_fp8, ) -if _is_cuda or _is_rocm: +if _is_cuda or _is_hip: from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size @@ -679,7 +679,7 @@ def get_default_config( "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 2 if is_hip_ else 4, + "num_stages": 2 if _is_hip else 4, } if M <= E: config = { @@ -688,7 +688,7 @@ def get_default_config( "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 if is_hip_ else 4, + "num_stages": 2 if _is_hip else 4, } else: # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1] @@ -698,7 +698,7 @@ def get_default_config( "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 2 if is_hip_ else 3, + "num_stages": 2 if _is_hip else 3, } else: config = { @@ -976,7 +976,7 @@ def fused_experts_impl( if ( not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None - or (is_hip_ and get_bool_env_var("CK_MOE")) + or (_is_hip and get_bool_env_var("CK_MOE")) ): padded_size = 0 @@ -1131,7 +1131,7 @@ def fused_experts_impl( if no_combine: pass - elif is_hip_: + elif _is_hip: ops.moe_sum( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 2c3f722ce..42f697fbf 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -27,9 +27,9 @@ else: import logging -is_hip_ = is_hip() +_is_hip = is_hip() -if is_hip_: +if _is_hip: from aiter import ck_moe logger = logging.getLogger(__name__) @@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if is_hip_ and get_bool_env_var("CK_MOE"): + if _is_hip and get_bool_env_var("CK_MOE"): layer.w13_weight = torch.nn.Parameter( permute_weight(layer.w13_weight.data), requires_grad=False, @@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): correction_bias=correction_bias, ) - if is_hip_ and get_bool_env_var("CK_MOE"): + if _is_hip and get_bool_env_var("CK_MOE"): assert 
not no_combine, "unsupported" return ck_moe( x, @@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module): # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD) - if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): loaded_weight = loaded_weight * 2.0 # this is needed for compressed-tensors only @@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module): quant_method = getattr(param, "quant_method", None) if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD) - if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): loaded_weight = loaded_weight * 0.5 self._load_per_channel_weight_scale( @@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module): ) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD) - if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): loaded_weight = loaded_weight * 2.0 self._load_per_tensor_weight_scale( diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 79092813a..bff0fe96e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -54,9 +54,9 @@ from sglang.srt.utils import ( ACTIVATION_SCHEMES = ["static", "dynamic"] -is_hip_ = is_hip() +_is_hip = is_hip() -if is_hip_: +if _is_hip: from aiter.fused_moe_bf16_asm import asm_moe from aiter.ops.shuffle import shuffle_weight @@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase): # kernel for fast weight-only FP8 quantization self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") # Disable marlin for ROCm - if is_hip_: + if _is_hip: self.use_marlin = False self.block_quant = self.quant_config.weight_block_size is not None @@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase): # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip_: + if _is_hip: # activation_scheme: dynamic weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase): weight = layer.weight weight_scale = layer.weight_scale # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip_: + if _is_hip: weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale, @@ -563,7 +563,7 @@ class Fp8MoEMethod: layer.register_parameter("w2_weight_scale", w2_weight_scale) if ( - is_hip_ + _is_hip ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling w13_weight_scale1 = torch.nn.Parameter( @@ -630,7 +630,7 @@ class Fp8MoEMethod: # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip_: + if _is_hip: # activation_scheme: dynamic w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.w13_weight, @@ -667,7 +667,7 @@ class Fp8MoEMethod: # If checkpoint is fp16 or bfloat16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: # If ROCm, use float8_e4m3fnuz instead (MI300x HW) - fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -689,7 +689,7 @@ class Fp8MoEMethod: layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - if is_hip_: + if _is_hip: self.process_weights_hip_scale_padding(layer) return @@ -721,7 +721,7 @@ class Fp8MoEMethod: ) # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip_: + if _is_hip: # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( @@ -771,7 +771,7 @@ class Fp8MoEMethod: max_w13_scales, requires_grad=False ) - if is_hip_: + if _is_hip: self.process_weights_hip_scale_padding(layer) return @@ -882,7 +882,7 @@ class Fp8MoEMethod: correction_bias=correction_bias, ) - if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): # TODO: add triton kernel and add check get_bool_env_var("CK_MOE") assert not no_combine, f"{no_combine=} is not supported." return asm_moe( @@ -895,7 +895,7 @@ class Fp8MoEMethod: layer.w2_weight_scale1, activation=activation, ) - if is_hip_ and get_bool_env_var("CK_MOE"): + if _is_hip and get_bool_env_var("CK_MOE"): # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being. assert ( activation == "silu" diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 4ffeab8ba..1b61575ca 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -22,12 +22,12 @@ import torch import triton import triton.language as tl -from sglang.srt.utils import get_device_core_count, get_device_name, is_hip +from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -_is_cuda = torch.cuda.is_available() and torch.version.cuda +_is_cuda = is_cuda() if _is_cuda: import deep_gemm from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8 @@ -157,7 +157,7 @@ def per_token_group_quant_fp8( finfo = torch.finfo(dtype) fp8_max = finfo.max - if is_hip_: + if _is_hip: fp8_max = 224.0 fp8_min = -fp8_max @@ -332,7 +332,7 @@ def static_quant_fp8( finfo = torch.finfo(dtype) fp8_max = finfo.max - if is_hip_: + if _is_hip: fp8_max = 224.0 fp8_min = -fp8_max @@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul( else: kernel = ( _w8a8_block_fp8_matmul_unrolledx4 - if (is_hip_ == True and num_workgroups <= get_device_core_count()) + if (_is_hip == True and num_workgroups <= get_device_core_count()) else _w8a8_block_fp8_matmul ) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index b7b2f2b89..42c8c1371 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -17,8 +17,8 @@ from sglang.srt.utils import ( use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL") -is_hip_ = is_hip() -if 
is_hip_ and get_bool_env_var("CK_MOE"): +_is_hip = is_hip() +if _is_hip and get_bool_env_var("CK_MOE"): from aiter import gemm_a8w8_blockscale _is_cuda = is_cuda() @@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear( output = fp8_blockwise_scaled_mm( q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype ) - elif is_hip_ and get_bool_env_var("CK_MOE"): + elif _is_hip and get_bool_env_var("CK_MOE"): q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=False ) @@ -142,7 +142,7 @@ def input_to_float8( min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) fp8_max = finfo.max - if is_hip_: + if _is_hip: fp8_max = 224.0 scale = fp8_max / amax x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max) diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 0adedc68f..240d86927 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import ( ) from sglang.srt.utils import is_hip +_is_hip = is_hip() + class W8A8Fp8Config(QuantizationConfig): """Config class for W8A8 FP8 Quantization. @@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight = layer.weight weight_scale = layer.weight_scale.detach() - if is_hip(): + if _is_hip: weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 83c2d88f0..5ae558056 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.utils import is_hip -is_hip_ = is_hip() +_is_hip = is_hip() if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): else: capture_bs = list(range(1, 33)) - if is_hip_: + if _is_hip: capture_bs += [i * 8 for i in range(21, 33)] if max(capture_bs) > model_runner.req_to_token_pool.size: diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index c1c99be54..753e1ba1f 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM from sglang.srt.utils import add_prefix, is_hip -is_hip_ = is_hip() +_is_hip = is_hip() class DeepseekModelNextN(nn.Module): @@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): weight_block_size = self.quant_config.weight_block_size if weight_block_size is not None: assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if is_hip_: + if _is_hip: weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=w, weight_scale=self_attn.kv_b_proj.weight_scale_inv, @@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): and self_attn.w_scale is None ): self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if is_hip_: + if _is_hip: self_attn.w_scale *= 2.0 diff --git a/python/sglang/srt/models/deepseek_v2.py 
b/python/sglang/srt/models/deepseek_v2.py
index 40f6799a1..d0ca14feb 100755
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 if is_cuda_available():
     from sgl_kernel import bmm_fp8
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         if no_absorb():
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-            if is_hip_:
+            if _is_hip:
                 if (
                     os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
                     and forward_batch.forward_mode.is_decode()
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if is_hip_:
+                    if _is_hip:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                             weight=w,
                             weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     and self_attn.w_scale is None
                 ):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if is_hip_:
+                    if _is_hip:
                         self_attn.w_scale *= 2.0
 
     def get_embed_and_head(self):
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8bfdbc0ed..f0eb2495c 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -72,13 +72,17 @@ show_time_cost = False
 time_infos = {}
 
 
+# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
-    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip is not None
+
+
 def is_cuda():
-    return hasattr(torch, "cuda") and torch.version.cuda is not None
+    return torch.cuda.is_available() and torch.version.cuda is not None
 
 
 def is_cuda_alike():
@@ -100,11 +104,11 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()
 
 
 def is_cuda_available():
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()
 
 
 def enable_show_time_cost():
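
The convention this patch converges on is: define the platform probes once in
sglang.srt.utils, evaluate them a single time at module import under the
underscore-prefixed names _is_cuda / _is_hip, and branch on the cached flags.
The sketch below condenses that pattern from the custom_op.py and utils.py
hunks above. It is a minimal illustration, not the full classes: the forward_*
bodies are placeholder stubs, and the explicit "is not None" checks are an
assumption that keeps the helpers returning real booleans (torch.version.cuda
and torch.version.hip are version strings or None, so the bare expressions
would otherwise return a string).

import torch
from torch import nn


def is_cuda() -> bool:
    # torch.version.cuda is a version string on CUDA builds, else None.
    return torch.cuda.is_available() and torch.version.cuda is not None


def is_hip() -> bool:
    # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
    return torch.version.hip is not None


# Probe once at import time; consuming modules branch on the cached flags.
_is_cuda = is_cuda()
_is_hip = is_hip()


class CustomOp(nn.Module):
    """Skeleton of the dispatch pattern in python/sglang/srt/custom_op.py."""

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError  # placeholder stub

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError  # placeholder stub

    def forward_hip(self, *args, **kwargs):
        raise NotImplementedError  # placeholder stub

    def dispatch_forward(self):
        # Same precedence as the patched custom_op.py: CUDA first, then HIP,
        # then the pure-PyTorch fallback.
        if _is_cuda:
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        else:
            return self.forward_native

Evaluating the probes once at import time avoids re-querying torch.version on
every call, and the single _is_cuda / _is_hip spelling (replacing the mixed
is_hip_, _is_rocm, and inline is_hip() usages removed above) makes platform
branches straightforward to grep across the tree.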