unify is_cuda and is_hip (#4321)
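
This commit replaces scattered platform checks (`torch.cuda.is_available() and torch.version.cuda`, the mixed `is_hip_` / `_is_rocm` names, and repeated `is_cuda()` / `is_hip()` calls) with module-level flags computed once from the helpers in `sglang.srt.utils`. A minimal sketch of the resulting pattern, using only names that appear in this diff (the function below is illustrative and not part of the commit):

# Illustrative sketch of the unified platform-flag pattern; not part of the commit.
import torch

from sglang.srt.utils import is_cuda, is_hip

# Evaluate the backend check once at import time and reuse the cached flags.
_is_cuda = is_cuda()
_is_hip = is_hip()


def pick_fp8_dtype() -> torch.dtype:
    # ROCm uses the e4m3fnuz variant, CUDA uses e4m3fn (as in the fp8 hunks below).
    return torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
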
@@ -1,8 +1,10 @@
 import torch
 from torch import nn
 
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
 
 
 class CustomOp(nn.Module):
@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
     def dispatch_forward(self):
         if _is_cuda:
             return self.forward_cuda
-        elif _is_rocm:
+        elif _is_hip:
             return self.forward_hip
         else:
             return self.forward_native
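
A condensed, hypothetical subclass showing how the dispatch above is meant to be consumed; only `dispatch_forward` and the flag names come from the hunk, the rest is illustrative:

# Hypothetical CustomOp-style module; only the dispatch logic mirrors the diff.
import torch
from torch import nn

from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()


class ExampleOp(nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x)  # pure-PyTorch fallback

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)  # a fused CUDA kernel would go here

    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)  # a ROCm-specific path would go here

    def dispatch_forward(self):
        # Same flag-based selection as the hunk above.
        if _is_cuda:
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        else:
            return self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dispatch_forward()(x)
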
@@ -22,15 +22,16 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
 
 logger = logging.getLogger(__name__)
 
-is_hip_ = is_hip()
+_is_cuda = is_cuda()
+_is_hip = is_hip()
 
-if is_cuda():
+if _is_cuda:
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)
 
-if is_hip_:
+if _is_hip:
     try:
         from amdsmi import (
             AmdSmiException,
@@ -43,7 +44,7 @@ if is_hip_:
         logger.warning("Failed to import amdsmi with %r", e)
 
 try:
-    if ops.use_vllm_custom_allreduce and not is_hip_:
+    if ops.use_vllm_custom_allreduce and not _is_hip:
         # Use vLLM custom allreduce
         ops.meta_size()
     else:
@@ -63,7 +64,7 @@ _R = TypeVar("_R")
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if is_hip_:
+        if _is_hip:
             try:
                 amdsmi_init()
                 return fn(*args, **kwargs)
@@ -81,7 +82,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
 @with_nvml_context
 def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip_:
+    if _is_hip:
         """
         query if the set of gpus are fully connected by xgmi (1 hop)
         """
@@ -145,7 +146,7 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip_:
+    if _is_hip:
         # crossover is at 16MB buffer size for ROCm
         _MAX_CAR_SIZE = 2 * 8192 * 1024
 
@@ -229,7 +230,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip_:
+        if _is_cuda or _is_hip:
            full_nvlink = is_full_nvlink(physical_device_ids, world_size)
 
        if world_size > 2 and not full_nvlink:
@@ -243,7 +244,7 @@ class CustomAllreduce:
        # this is expensive to compute at the first time
        # then we cache the result
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip_ and not _can_p2p(rank, world_size):
+        if not _is_hip and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
@@ -256,7 +257,7 @@ class CustomAllreduce:
        self.world_size = world_size
        self.full_nvlink = full_nvlink
 
-        if ops.use_vllm_custom_allreduce and not is_hip_:
+        if ops.use_vllm_custom_allreduce and not _is_hip:
            # Buffers memory are owned by this Python class and passed to C++.
            # Meta data composes of two parts: meta data for synchronization and a
            # temporary buffer for storing intermediate allreduce results.
@@ -279,7 +280,7 @@ class CustomAllreduce:
            )
            ops.register_buffer(self._ptr, self.buffer_ptrs)
        else:
-            if is_hip_:
+            if _is_hip:
                # meta data buffers need to be "uncached" for signal on MI200
                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                self.buffer = torch.empty(
@@ -418,7 +419,7 @@ class CustomAllreduce:
        ops.register_buffer(self._ptr, inp, handles, offsets)
 
    def register_graph_buffers(self):
-        if is_hip_:
+        if _is_hip:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
            logger.info("Registering %d cuda graph addresses", len(offset))
@@ -454,12 +455,12 @@ class CustomAllreduce:
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip_:
+        if ops.use_vllm_custom_allreduce and not _is_hip:
            if self.world_size == 2 or self.full_nvlink:
                return inp_size < self.max_size
            return False
 
-        if is_hip_:
+        if _is_hip:
            if self.full_nvlink:
                if self.world_size == 8:
                    if self.MSCCL:
@@ -532,7 +533,7 @@ class CustomAllreduce:
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
-                if is_hip_:
+                if _is_hip:
                    return self.all_reduce_reg(input)
                else:
                    return self.all_reduce(input, registered=True)
@@ -541,7 +542,7 @@ class CustomAllreduce:
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
-            if is_hip_:
+            if _is_hip:
                # note: outside of cuda graph context,
                # custom allreduce incurs a cost of cudaMemcpy, which should
                # be small(<=1% of overall latency) compared to the performance
@@ -556,7 +557,7 @@ class CustomAllreduce:
        if ops.use_vllm_custom_allreduce:
            self.free_shared_buffer(self.meta_ptrs)
            self.free_shared_buffer(self.buffer_ptrs)
-        elif is_cuda():
+        elif _is_cuda:
            self.free_shared_buffer(self.buffer_ptrs)
            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
            self.free_shared_buffer(self.barrier_in_ptrs)
@@ -27,7 +27,7 @@ import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx
-    if is_hip_:
+    if _is_hip:
         BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
-        if is_hip_:
+        if _is_hip:
             num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
     Lv = v_buffer.shape[-1]
 
     # [TODO] work around shmem limit on MI3xx
-    if is_hip_ and Lk >= 576:
+    if _is_hip and Lk >= 576:
         BLOCK = 16
 
     if Lk == 576:
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
 
     extra_kargs = {}
     num_stages = 2
-    if is_hip_:
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
     NUM_KV_SPLITS = num_kv_splits
 
     extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
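
The ROCm-only `extra_kargs` in these kernels feed AMD-specific compiler knobs (`waves_per_eu`, `matrix_instr_nonkdim`, `kpack`) into the Triton launch. A hedged sketch of that pattern with a toy kernel (the kernel and grid are illustrative; only the kwargs and the flag come from the diff):

# Illustrative Triton launch showing the ROCm-only tuning kwargs from the hunks above.
import torch
import triton
import triton.language as tl

from sglang.srt.utils import is_hip

_is_hip = is_hip()


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    BLOCK = 1024
    grid = (triton.cdiv(src.numel(), BLOCK),)
    extra_kargs = {}
    if _is_hip:
        # AMD backend knobs; see the ROCm Triton tuning guide linked in the diff.
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    _copy_kernel[grid](src, dst, src.numel(), BLOCK=BLOCK, **extra_kargs)
    return dst
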
@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if is_hip_:
+    if _is_hip:
         BLOCK_M, BLOCK_N = (64, 64)
         num_warps = 4
 
@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
         num_stages = 1
 
     extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -330,7 +330,7 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if is_hip_:
+    if _is_hip:
         BLOCK_M, BLOCK_N = (64, 64)
         num_warps = 4
 
@@ -364,7 +364,7 @@ def extend_attention_fwd(
         num_stages = 1
 
     extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
@@ -403,7 +403,7 @@ def extend_attention_fwd(
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        STORE_TRANSPOSE=is_hip_,
+        STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
@@ -32,7 +32,7 @@ def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
 
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
     BLOCK = 32
 
     # # [TODO] work around shmem limit on MI3xx
-    # if is_hip_ and kv_lora_rank >= 576:
+    # if _is_hip and kv_lora_rank >= 576:
     #     BLOCK = 16
 
     qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
 
     extra_kargs = {}
     num_stages = 2
-    if is_hip_:
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -6,8 +6,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import is_cuda
 
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
+_is_hip = is_hip()
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -23,10 +23,11 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cuda,
     is_hip,
 )
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 
 logger = logging.getLogger(__name__)
@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
 
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
@@ -46,7 +46,7 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8,
     )
 
-if _is_cuda or _is_rocm:
+if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
@@ -679,7 +679,7 @@ def get_default_config(
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 2 if is_hip_ else 4,
+            "num_stages": 2 if _is_hip else 4,
         }
         if M <= E:
             config = {
@@ -688,7 +688,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 2 if is_hip_ else 4,
+                "num_stages": 2 if _is_hip else 4,
             }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -698,7 +698,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 2 if is_hip_ else 3,
+                "num_stages": 2 if _is_hip else 3,
             }
     else:
         config = {
@@ -976,7 +976,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (is_hip_ and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("CK_MOE"))
     ):
         padded_size = 0
 
@@ -1131,7 +1131,7 @@ def fused_experts_impl(
 
         if no_combine:
             pass
-        elif is_hip_:
+        elif _is_hip:
            ops.moe_sum(
                intermediate_cache3.view(*intermediate_cache3.shape),
                out_hidden_states[begin_chunk_idx:end_chunk_idx],
@@ -27,9 +27,9 @@ else:
 
 import logging
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
-if is_hip_:
+if _is_hip:
     from aiter import ck_moe
 
 logger = logging.getLogger(__name__)
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
            layer.w13_weight = torch.nn.Parameter(
                permute_weight(layer.w13_weight.data),
                requires_grad=False,
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            correction_bias=correction_bias,
        )
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
            assert not no_combine, "unsupported"
            return ck_moe(
                x,
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0
 
            # this is needed for compressed-tensors only
@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-                if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 0.5
 
                self._load_per_channel_weight_scale(
@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
                )
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-                if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 2.0
 
                self._load_per_tensor_weight_scale(
@@ -54,9 +54,9 @@ from sglang.srt.utils import (
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
-if is_hip_:
+if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight
 
@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip_:
+        if _is_hip:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -563,7 +563,7 @@ class Fp8MoEMethod:
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
         if (
-            is_hip_
+            _is_hip
         ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
@@ -630,7 +630,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -667,7 +667,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -689,7 +689,7 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            if is_hip_:
+            if _is_hip:
                 self.process_weights_hip_scale_padding(layer)
             return
 
@@ -721,7 +721,7 @@ class Fp8MoEMethod:
                 )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -771,7 +771,7 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            if is_hip_:
+            if _is_hip:
                 self.process_weights_hip_scale_padding(layer)
             return
 
@@ -882,7 +882,7 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
            assert not no_combine, f"{no_combine=} is not supported."
            return asm_moe(
@@ -895,7 +895,7 @@ class Fp8MoEMethod:
                layer.w2_weight_scale1,
                activation=activation,
            )
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
            assert (
                activation == "silu"
@@ -22,12 +22,12 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
 
-is_hip_ = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
 
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
@@ -157,7 +157,7 @@ def per_token_group_quant_fp8(
     finfo = torch.finfo(dtype)
     fp8_max = finfo.max
 
-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0
 
     fp8_min = -fp8_max
@@ -332,7 +332,7 @@ def static_quant_fp8(
     finfo = torch.finfo(dtype)
     fp8_max = finfo.max
 
-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0
 
     fp8_min = -fp8_max
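
On ROCm these kernels clamp the FP8 maximum to 224.0 instead of using `torch.finfo(dtype).max`, matching the e4m3fnuz handling elsewhere in the diff. A hedged per-tensor sketch of that scaling logic (the helper name is illustrative, not the repo's API):

# Illustrative per-tensor FP8 quantization using the platform-dependent max above.
import torch

from sglang.srt.utils import is_hip

_is_hip = is_hip()
_fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn


def quantize_per_tensor_fp8(x: torch.Tensor):
    fp8_max = torch.finfo(_fp8_dtype).max
    if _is_hip:
        fp8_max = 224.0  # same clamp as the ROCm branch in the kernels above
    amax = x.abs().amax().clamp(min=1e-12)
    scale = fp8_max / amax
    x_q = (x * scale).clamp(min=-fp8_max, max=fp8_max).to(_fp8_dtype)
    # Return the quantized tensor and the scale needed to recover x (x ~= x_q / scale).
    return x_q, scale
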
@@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul(
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
-            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            if (_is_hip == True and num_workgroups <= get_device_core_count())
             else _w8a8_block_fp8_matmul
         )
 
@@ -17,8 +17,8 @@ from sglang.srt.utils import (
 
 use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
-is_hip_ = is_hip()
-if is_hip_ and get_bool_env_var("CK_MOE"):
+_is_hip = is_hip()
+if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
 _is_cuda = is_cuda()
@@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif is_hip_ and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("CK_MOE"):
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
        )
@@ -142,7 +142,7 @@ def input_to_float8(
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
@@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import is_hip
 
+_is_hip = is_hip()
+
 
 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.
@@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
         weight_scale = layer.weight_scale.detach()
-        if is_hip():
+        if _is_hip:
             weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=weight, weight_scale=weight_scale
             )
@@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.utils import is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     else:
         capture_bs = list(range(1, 33))
 
-    if is_hip_:
+    if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
 from sglang.srt.utils import add_prefix, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 
 class DeepseekModelNextN(nn.Module):
@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             weight_block_size = self.quant_config.weight_block_size
             if weight_block_size is not None:
                 assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if is_hip_:
+                if _is_hip:
                     weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                         weight=w,
                         weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                     self_attn.w_scale *= 2.0
 
 
@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 if is_cuda_available():
     from sgl_kernel import bmm_fp8
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         if no_absorb():
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-            if is_hip_:
+            if _is_hip:
                 if (
                     os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
                     and forward_batch.forward_mode.is_decode()
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             weight_block_size = self.quant_config.weight_block_size
             if weight_block_size is not None:
                 assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if is_hip_:
+                if _is_hip:
                     weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                         weight=w,
                         weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                     self_attn.w_scale *= 2.0
 
     def get_embed_and_head(self):
@@ -72,13 +72,17 @@ show_time_cost = False
 time_infos = {}
 
 
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip
+
+
 def is_cuda():
-    return hasattr(torch, "cuda") and torch.version.cuda is not None
+    return torch.cuda.is_available() and torch.version.cuda
 
 
 def is_cuda_alike():
@@ -100,11 +104,11 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()
 
 
 def is_cuda_available():
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()
 
 
 def enable_show_time_cost():
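
With `is_flashinfer_available()` and `is_cuda_available()` now delegating to `is_cuda()`, call sites can use a single cached flag to guard CUDA-only imports and device queries. A small hedged usage sketch (the surrounding module is hypothetical):

# Hypothetical call site using the unified helpers defined above.
import torch

from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()

CUDA_CAPABILITY = None
if _is_cuda:
    # Only query the device when a CUDA build of torch is active, as in the diff.
    CUDA_CAPABILITY = torch.cuda.get_device_capability()

backend = "cuda" if _is_cuda else ("hip" if _is_hip else "cpu")
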