diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index e1064bcda..d95498377 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,6 +1,5 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
-import os
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +27,8 @@ else:
 
 import logging
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
@@ -99,7 +100,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
                 permute_weight(layer.w13_weight.data),
                 requires_grad=False,
             )
@@ -163,7 +164,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )
 
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             import ater
             from ater.fused_moe import fused_experts_ck
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 22a43675b..d16a3b0c2 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -47,6 +46,8 @@ from sglang.srt.utils import (
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
@@ -162,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -274,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -331,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -568,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -595,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -617,8 +618,8 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            if is_hip():
-                if bool(int(os.getenv("CK_MOE", "0"))):
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
                     layer.w13_weight = torch.nn.Parameter(
                         permute_weight(layer.w13_weight.data),
                         requires_grad=False,
                     )
@@ -629,7 +630,7 @@ class Fp8MoEMethod:
                         requires_grad=False,
                     )
                     torch.cuda.empty_cache()
-                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                elif get_bool_env_var("MOE_PADDING"):
                     # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -671,7 +672,7 @@ class Fp8MoEMethod:
             )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -721,8 +722,8 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            if is_hip():
-                if bool(int(os.getenv("CK_MOE", "0"))):
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
                     layer.w13_weight = torch.nn.Parameter(
                         permute_weight(layer.w13_weight.data),
                         requires_grad=False,
                     )
@@ -733,7 +734,7 @@ class Fp8MoEMethod:
                         requires_grad=False,
                     )
                     torch.cuda.empty_cache()
-                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                elif get_bool_env_var("MOE_PADDING"):
                     # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -777,7 +778,7 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
            import ater
            from ater.fused_moe import fused_experts_ck
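Note for reviewers: the two helpers this diff standardizes on are defined in sglang.srt.utils. A minimal sketch of the intended pattern follows; the bodies below are assumptions about their behavior for illustration, not copies of the source.

    # Sketch of the helpers used above; the real definitions live in
    # sglang/srt/utils.py (bodies here are assumed behavior, not source).
    import os

    import torch


    def is_hip() -> bool:
        # True when PyTorch is a ROCm/HIP build (e.g. for MI300x).
        return torch.version.hip is not None


    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Treats "true"/"1" (case-insensitive) as truthy. The replaced
        # bool(int(os.getenv(name, "0"))) pattern raised ValueError on
        # non-integer values such as "true".
        return os.getenv(name, default).lower() in ("true", "1")


    # Evaluated once at import time; hot paths such as
    # process_weights_after_loading and forward then branch on the cached
    # module-level flag instead of re-querying the platform every call.
    is_hip_ = is_hip()

Under these assumptions, both CK_MOE=1 and CK_MOE=true would enable the CK fused-MoE path, whereas the old parsing accepted only integer strings.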