Unify sglang coding style (#2856)
Co-authored-by: Lin, Soga <soga.lin@amd.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
@@ -19,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
@@ -28,6 +27,8 @@ else:
|
||||
|
||||
import logging
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -99,7 +100,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
@@ -163,7 +164,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
correction_bias=correction_bias,
|
||||
)
|
||||
|
||||
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
import ater
|
||||
from ater.fused_moe import fused_experts_ck
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -47,6 +46,8 @@ from sglang.srt.utils import (
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -162,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
|
||||
# Disable marlin for ROCm
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
self.use_marlin = False
|
||||
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
@@ -274,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Block quant doesn't need to process weights after loading
|
||||
if self.block_quant:
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# activation_scheme: dynamic
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
@@ -331,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
@@ -568,7 +569,7 @@ class Fp8MoEMethod:
|
||||
# Block quant doesn't need to process weights after loading
|
||||
if self.block_quant:
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# activation_scheme: dynamic
|
||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w13_weight,
|
||||
@@ -595,7 +596,7 @@ class Fp8MoEMethod:
|
||||
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
||||
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
|
||||
fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
@@ -617,8 +618,8 @@ class Fp8MoEMethod:
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
if is_hip():
|
||||
if bool(int(os.getenv("CK_MOE", "0"))):
|
||||
if is_hip_:
|
||||
if get_bool_env_var("CK_MOE"):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
@@ -629,7 +630,7 @@ class Fp8MoEMethod:
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
elif bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||
elif get_bool_env_var("MOE_PADDING"):
|
||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||
@@ -671,7 +672,7 @@ class Fp8MoEMethod:
|
||||
)
|
||||
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
@@ -721,8 +722,8 @@ class Fp8MoEMethod:
|
||||
max_w13_scales, requires_grad=False
|
||||
)
|
||||
|
||||
if is_hip():
|
||||
if bool(int(os.getenv("CK_MOE", "0"))):
|
||||
if is_hip_:
|
||||
if get_bool_env_var("CK_MOE"):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
@@ -733,7 +734,7 @@ class Fp8MoEMethod:
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
elif bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||
elif get_bool_env_var("MOE_PADDING"):
|
||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||
@@ -777,7 +778,7 @@ class Fp8MoEMethod:
|
||||
correction_bias=correction_bias,
|
||||
)
|
||||
|
||||
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
import ater
|
||||
from ater.fused_moe import fused_experts_ck
|
||||
|
||||
|
||||
Reference in New Issue
Block a user