unify is_cuda and is_hip (#4321)
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
_is_rocm = torch.cuda.is_available() and torch.version.hip
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
class CustomOp(nn.Module):
|
class CustomOp(nn.Module):
|
||||||
@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
|
|||||||
def dispatch_forward(self):
|
def dispatch_forward(self):
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
return self.forward_cuda
|
return self.forward_cuda
|
||||||
elif _is_rocm:
|
elif _is_hip:
|
||||||
return self.forward_hip
|
return self.forward_hip
|
||||||
else:
|
else:
|
||||||
return self.forward_native
|
return self.forward_native
|
||||||
|
|||||||
@@ -22,15 +22,16 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_cuda = is_cuda()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if is_cuda():
|
if _is_cuda:
|
||||||
try:
|
try:
|
||||||
import pynvml
|
import pynvml
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Failed to import pynvml with %r", e)
|
logger.warning("Failed to import pynvml with %r", e)
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
try:
|
try:
|
||||||
from amdsmi import (
|
from amdsmi import (
|
||||||
AmdSmiException,
|
AmdSmiException,
|
||||||
@@ -43,7 +44,7 @@ if is_hip_:
|
|||||||
logger.warning("Failed to import amdsmi with %r", e)
|
logger.warning("Failed to import amdsmi with %r", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ops.use_vllm_custom_allreduce and not is_hip_:
|
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||||
# Use vLLM custom allreduce
|
# Use vLLM custom allreduce
|
||||||
ops.meta_size()
|
ops.meta_size()
|
||||||
else:
|
else:
|
||||||
@@ -63,7 +64,7 @@ _R = TypeVar("_R")
|
|||||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
try:
|
try:
|
||||||
amdsmi_init()
|
amdsmi_init()
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
@@ -81,7 +82,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
|||||||
|
|
||||||
@with_nvml_context
|
@with_nvml_context
|
||||||
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
"""
|
"""
|
||||||
query if the set of gpus are fully connected by xgmi (1 hop)
|
query if the set of gpus are fully connected by xgmi (1 hop)
|
||||||
"""
|
"""
|
||||||
@@ -145,7 +146,7 @@ def is_weak_contiguous(inp: torch.Tensor):
|
|||||||
class CustomAllreduce:
|
class CustomAllreduce:
|
||||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||||
_MAX_CAR_SIZE = 8192 * 1024
|
_MAX_CAR_SIZE = 8192 * 1024
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# crossover is at 16MB buffer size for ROCm
|
# crossover is at 16MB buffer size for ROCm
|
||||||
_MAX_CAR_SIZE = 2 * 8192 * 1024
|
_MAX_CAR_SIZE = 2 * 8192 * 1024
|
||||||
|
|
||||||
@@ -229,7 +230,7 @@ class CustomAllreduce:
|
|||||||
# test nvlink first, this will filter out most of the cases
|
# test nvlink first, this will filter out most of the cases
|
||||||
# where custom allreduce is not supported
|
# where custom allreduce is not supported
|
||||||
# this checks hardware and driver support for NVLink
|
# this checks hardware and driver support for NVLink
|
||||||
if is_cuda() or is_hip_:
|
if _is_cuda or _is_hip:
|
||||||
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
|
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
|
||||||
|
|
||||||
if world_size > 2 and not full_nvlink:
|
if world_size > 2 and not full_nvlink:
|
||||||
@@ -243,7 +244,7 @@ class CustomAllreduce:
|
|||||||
# this is expensive to compute at the first time
|
# this is expensive to compute at the first time
|
||||||
# then we cache the result
|
# then we cache the result
|
||||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||||
if not is_hip_ and not _can_p2p(rank, world_size):
|
if not _is_hip and not _can_p2p(rank, world_size):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Custom allreduce is disabled because your platform lacks "
|
"Custom allreduce is disabled because your platform lacks "
|
||||||
"GPU P2P capability or P2P test failed. To silence this "
|
"GPU P2P capability or P2P test failed. To silence this "
|
||||||
@@ -256,7 +257,7 @@ class CustomAllreduce:
|
|||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.full_nvlink = full_nvlink
|
self.full_nvlink = full_nvlink
|
||||||
|
|
||||||
if ops.use_vllm_custom_allreduce and not is_hip_:
|
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||||
# Buffer memory is owned by this Python class and passed to C++.
|
# Buffer memory is owned by this Python class and passed to C++.
|
||||||
# Metadata is composed of two parts: metadata for synchronization and a
|
# Metadata is composed of two parts: metadata for synchronization and a
|
||||||
# temporary buffer for storing intermediate allreduce results.
|
# temporary buffer for storing intermediate allreduce results.
|
||||||
@@ -279,7 +280,7 @@ class CustomAllreduce:
|
|||||||
)
|
)
|
||||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||||
else:
|
else:
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# meta data buffers need to be "uncached" for signal on MI200
|
# meta data buffers need to be "uncached" for signal on MI200
|
||||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||||
self.buffer = torch.empty(
|
self.buffer = torch.empty(
|
||||||
@@ -418,7 +419,7 @@ class CustomAllreduce:
|
|||||||
ops.register_buffer(self._ptr, inp, handles, offsets)
|
ops.register_buffer(self._ptr, inp, handles, offsets)
|
||||||
|
|
||||||
def register_graph_buffers(self):
|
def register_graph_buffers(self):
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||||
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
||||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||||
@@ -454,12 +455,12 @@ class CustomAllreduce:
|
|||||||
return False
|
return False
|
||||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||||
# little performance improvement over NCCL.
|
# little performance improvement over NCCL.
|
||||||
if ops.use_vllm_custom_allreduce and not is_hip_:
|
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||||
if self.world_size == 2 or self.full_nvlink:
|
if self.world_size == 2 or self.full_nvlink:
|
||||||
return inp_size < self.max_size
|
return inp_size < self.max_size
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
if self.full_nvlink:
|
if self.full_nvlink:
|
||||||
if self.world_size == 8:
|
if self.world_size == 8:
|
||||||
if self.MSCCL:
|
if self.MSCCL:
|
||||||
@@ -532,7 +533,7 @@ class CustomAllreduce:
|
|||||||
return None
|
return None
|
||||||
if self._IS_CAPTURING:
|
if self._IS_CAPTURING:
|
||||||
if torch.cuda.is_current_stream_capturing():
|
if torch.cuda.is_current_stream_capturing():
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
return self.all_reduce_reg(input)
|
return self.all_reduce_reg(input)
|
||||||
else:
|
else:
|
||||||
return self.all_reduce(input, registered=True)
|
return self.all_reduce(input, registered=True)
|
||||||
@@ -541,7 +542,7 @@ class CustomAllreduce:
|
|||||||
# allreduce is out-of-place.
|
# allreduce is out-of-place.
|
||||||
return torch.empty_like(input)
|
return torch.empty_like(input)
|
||||||
else:
|
else:
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# note: outside of cuda graph context,
|
# note: outside of cuda graph context,
|
||||||
# custom allreduce incurs a cost of cudaMemcpy, which should
|
# custom allreduce incurs a cost of cudaMemcpy, which should
|
||||||
# be small(<=1% of overall latency) compared to the performance
|
# be small(<=1% of overall latency) compared to the performance
|
||||||
@@ -556,7 +557,7 @@ class CustomAllreduce:
|
|||||||
if ops.use_vllm_custom_allreduce:
|
if ops.use_vllm_custom_allreduce:
|
||||||
self.free_shared_buffer(self.meta_ptrs)
|
self.free_shared_buffer(self.meta_ptrs)
|
||||||
self.free_shared_buffer(self.buffer_ptrs)
|
self.free_shared_buffer(self.buffer_ptrs)
|
||||||
elif is_cuda():
|
elif _is_cuda:
|
||||||
self.free_shared_buffer(self.buffer_ptrs)
|
self.free_shared_buffer(self.buffer_ptrs)
|
||||||
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
|
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
|
||||||
self.free_shared_buffer(self.barrier_in_ptrs)
|
self.free_shared_buffer(self.barrier_in_ptrs)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
|
|||||||
):
|
):
|
||||||
BLOCK = 64
|
BLOCK = 64
|
||||||
# [TODO] work around SGPR limit on MI3xx
|
# [TODO] work around SGPR limit on MI3xx
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
BLOCK = 8
|
BLOCK = 8
|
||||||
NUM_KV_SPLITS = num_kv_splits
|
NUM_KV_SPLITS = num_kv_splits
|
||||||
Lk = k_buffer.shape[-1]
|
Lk = k_buffer.shape[-1]
|
||||||
@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
|
|||||||
num_warps = 4
|
num_warps = 4
|
||||||
else:
|
else:
|
||||||
num_warps = 2
|
num_warps = 2
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
num_warps = 1
|
num_warps = 1
|
||||||
|
|
||||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||||
@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
|
|
||||||
# [TODO] work around shmem limit on MI3xx
|
# [TODO] work around shmem limit on MI3xx
|
||||||
if is_hip_ and Lk >= 576:
|
if _is_hip and Lk >= 576:
|
||||||
BLOCK = 16
|
BLOCK = 16
|
||||||
|
|
||||||
if Lk == 576:
|
if Lk == 576:
|
||||||
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
num_stages = 2
|
num_stages = 2
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||||
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
NUM_KV_SPLITS = num_kv_splits
|
NUM_KV_SPLITS = num_kv_splits
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
|
|||||||
if is_cuda_available:
|
if is_cuda_available:
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||||
REDUCE_TRITON_TYPE = tl.float32
|
REDUCE_TRITON_TYPE = tl.float32
|
||||||
@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
|
|||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
BLOCK_M, BLOCK_N = (64, 64)
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
num_warps = 4
|
num_warps = 4
|
||||||
|
|
||||||
@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
|
|||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|
||||||
_fwd_kernel[grid](
|
_fwd_kernel[grid](
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
|
|||||||
if is_cuda_available:
|
if is_cuda_available:
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -330,7 +330,7 @@ def extend_attention_fwd(
|
|||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
BLOCK_M, BLOCK_N = (64, 64)
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
num_warps = 4
|
num_warps = 4
|
||||||
|
|
||||||
@@ -364,7 +364,7 @@ def extend_attention_fwd(
|
|||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|
||||||
_fwd_kernel[grid](
|
_fwd_kernel[grid](
|
||||||
@@ -403,7 +403,7 @@ def extend_attention_fwd(
|
|||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||||
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
||||||
STORE_TRANSPOSE=is_hip_,
|
STORE_TRANSPOSE=_is_hip,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
**extra_kargs,
|
**extra_kargs,
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def is_hip():
|
|||||||
return triton.runtime.driver.active.get_current_target().backend == "hip"
|
return triton.runtime.driver.active.get_current_target().backend == "hip"
|
||||||
|
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
|
|||||||
BLOCK = 32
|
BLOCK = 32
|
||||||
|
|
||||||
# # [TODO] work around shmem limit on MI3xx
|
# # [TODO] work around shmem limit on MI3xx
|
||||||
# if is_hip_ and kv_lora_rank >= 576:
|
# if _is_hip and kv_lora_rank >= 576:
|
||||||
# BLOCK = 16
|
# BLOCK = 16
|
||||||
|
|
||||||
qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
|
qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
|
||||||
@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
|
|||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
num_stages = 2
|
num_stages = 2
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||||
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
_is_cuda = is_cuda()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
class GroupedGemmRunner(torch.nn.Module):
|
class GroupedGemmRunner(torch.nn.Module):
|
||||||
flashinfer_gemm_warpper = None
|
flashinfer_gemm_warpper = None
|
||||||
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
# If checkpoint is fp16, quantize in place.
|
# If checkpoint is fp16, quantize in place.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
# If rocm, use float8_e4m3fnuz as dtype
|
# If rocm, use float8_e4m3fnuz as dtype
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
|
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -23,10 +23,11 @@ from sglang.srt.utils import (
|
|||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_device_name,
|
get_device_name,
|
||||||
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
|
|||||||
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
_is_cuda = is_cuda()
|
||||||
_is_rocm = torch.cuda.is_available() and torch.version.hip
|
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||||
@@ -46,7 +46,7 @@ if _is_cuda:
|
|||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_cuda or _is_rocm:
|
if _is_cuda or _is_hip:
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
|
|
||||||
|
|
||||||
@@ -679,7 +679,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 2 if is_hip_ else 4,
|
"num_stages": 2 if _is_hip else 4,
|
||||||
}
|
}
|
||||||
if M <= E:
|
if M <= E:
|
||||||
config = {
|
config = {
|
||||||
@@ -688,7 +688,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2 if is_hip_ else 4,
|
"num_stages": 2 if _is_hip else 4,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
|
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||||
@@ -698,7 +698,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": block_shape[1],
|
"BLOCK_SIZE_K": block_shape[1],
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2 if is_hip_ else 3,
|
"num_stages": 2 if _is_hip else 3,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config = {
|
config = {
|
||||||
@@ -976,7 +976,7 @@ def fused_experts_impl(
|
|||||||
if (
|
if (
|
||||||
not (use_fp8_w8a8 or use_int8_w8a8)
|
not (use_fp8_w8a8 or use_int8_w8a8)
|
||||||
or block_shape is not None
|
or block_shape is not None
|
||||||
or (is_hip_ and get_bool_env_var("CK_MOE"))
|
or (_is_hip and get_bool_env_var("CK_MOE"))
|
||||||
):
|
):
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
|
|
||||||
@@ -1131,7 +1131,7 @@ def fused_experts_impl(
|
|||||||
|
|
||||||
if no_combine:
|
if no_combine:
|
||||||
pass
|
pass
|
||||||
elif is_hip_:
|
elif _is_hip:
|
||||||
ops.moe_sum(
|
ops.moe_sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ else:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
from aiter import ck_moe
|
from aiter import ck_moe
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
permute_weight(layer.w13_weight.data),
|
permute_weight(layer.w13_weight.data),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
assert not no_combine, "unsupported"
|
assert not no_combine, "unsupported"
|
||||||
return ck_moe(
|
return ck_moe(
|
||||||
x,
|
x,
|
||||||
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
# Case input scale: input_scale loading is only supported for fp8
|
# Case input scale: input_scale loading is only supported for fp8
|
||||||
if "input_scale" in weight_name:
|
if "input_scale" in weight_name:
|
||||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
|
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
|
||||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
loaded_weight = loaded_weight * 2.0
|
loaded_weight = loaded_weight * 2.0
|
||||||
|
|
||||||
# this is needed for compressed-tensors only
|
# this is needed for compressed-tensors only
|
||||||
@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
quant_method = getattr(param, "quant_method", None)
|
quant_method = getattr(param, "quant_method", None)
|
||||||
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
|
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
|
||||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
loaded_weight = loaded_weight * 0.5
|
loaded_weight = loaded_weight * 0.5
|
||||||
|
|
||||||
self._load_per_channel_weight_scale(
|
self._load_per_channel_weight_scale(
|
||||||
@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
|
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
|
||||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
loaded_weight = loaded_weight * 2.0
|
loaded_weight = loaded_weight * 2.0
|
||||||
|
|
||||||
self._load_per_tensor_weight_scale(
|
self._load_per_tensor_weight_scale(
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe
|
from aiter.fused_moe_bf16_asm import asm_moe
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# kernel for fast weight-only FP8 quantization
|
# kernel for fast weight-only FP8 quantization
|
||||||
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
|
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
|
||||||
# Disable marlin for ROCm
|
# Disable marlin for ROCm
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
self.use_marlin = False
|
self.use_marlin = False
|
||||||
|
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# Block quant doesn't need to process weights after loading
|
# Block quant doesn't need to process weights after loading
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# activation_scheme: dynamic
|
# activation_scheme: dynamic
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
weight_scale = layer.weight_scale
|
weight_scale = layer.weight_scale
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=weight,
|
weight=weight,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@@ -563,7 +563,7 @@ class Fp8MoEMethod:
|
|||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
is_hip_
|
_is_hip
|
||||||
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
|
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
|
||||||
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
||||||
w13_weight_scale1 = torch.nn.Parameter(
|
w13_weight_scale1 = torch.nn.Parameter(
|
||||||
@@ -630,7 +630,7 @@ class Fp8MoEMethod:
|
|||||||
# Block quant doesn't need to process weights after loading
|
# Block quant doesn't need to process weights after loading
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# activation_scheme: dynamic
|
# activation_scheme: dynamic
|
||||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=layer.w13_weight,
|
weight=layer.w13_weight,
|
||||||
@@ -667,7 +667,7 @@ class Fp8MoEMethod:
|
|||||||
# If checkpoint is fp16 or bfloat16, quantize in place.
|
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||||
|
|
||||||
@@ -689,7 +689,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
self.process_weights_hip_scale_padding(layer)
|
self.process_weights_hip_scale_padding(layer)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -721,7 +721,7 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
# Normalize the weights and scales
|
# Normalize the weights and scales
|
||||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||||
normalize_e4m3fn_to_e4m3fnuz(
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
@@ -771,7 +771,7 @@ class Fp8MoEMethod:
|
|||||||
max_w13_scales, requires_grad=False
|
max_w13_scales, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
self.process_weights_hip_scale_padding(layer)
|
self.process_weights_hip_scale_padding(layer)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -882,7 +882,7 @@ class Fp8MoEMethod:
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
return asm_moe(
|
return asm_moe(
|
||||||
@@ -895,7 +895,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_weight_scale1,
|
layer.w2_weight_scale1,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
)
|
)
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
|
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
|
||||||
assert (
|
assert (
|
||||||
activation == "silu"
|
activation == "silu"
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
|
||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
_is_cuda = is_cuda()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||||
@@ -157,7 +157,7 @@ def per_token_group_quant_fp8(
|
|||||||
finfo = torch.finfo(dtype)
|
finfo = torch.finfo(dtype)
|
||||||
fp8_max = finfo.max
|
fp8_max = finfo.max
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
fp8_max = 224.0
|
fp8_max = 224.0
|
||||||
|
|
||||||
fp8_min = -fp8_max
|
fp8_min = -fp8_max
|
||||||
@@ -332,7 +332,7 @@ def static_quant_fp8(
|
|||||||
finfo = torch.finfo(dtype)
|
finfo = torch.finfo(dtype)
|
||||||
fp8_max = finfo.max
|
fp8_max = finfo.max
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
fp8_max = 224.0
|
fp8_max = 224.0
|
||||||
|
|
||||||
fp8_min = -fp8_max
|
fp8_min = -fp8_max
|
||||||
@@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul(
|
|||||||
else:
|
else:
|
||||||
kernel = (
|
kernel = (
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
if (is_hip_ == True and num_workgroups <= get_device_core_count())
|
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
||||||
else _w8a8_block_fp8_matmul
|
else _w8a8_block_fp8_matmul
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
|
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
from aiter import gemm_a8w8_blockscale
|
from aiter import gemm_a8w8_blockscale
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
output = fp8_blockwise_scaled_mm(
|
output = fp8_blockwise_scaled_mm(
|
||||||
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
||||||
)
|
)
|
||||||
elif is_hip_ and get_bool_env_var("CK_MOE"):
|
elif _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=False
|
input_2d, block_size[1], column_major_scales=False
|
||||||
)
|
)
|
||||||
@@ -142,7 +142,7 @@ def input_to_float8(
|
|||||||
min_val, max_val = x.aminmax()
|
min_val, max_val = x.aminmax()
|
||||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
fp8_max = finfo.max
|
fp8_max = finfo.max
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
fp8_max = 224.0
|
fp8_max = 224.0
|
||||||
scale = fp8_max / amax
|
scale = fp8_max / amax
|
||||||
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
|
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
class W8A8Fp8Config(QuantizationConfig):
|
class W8A8Fp8Config(QuantizationConfig):
|
||||||
"""Config class for W8A8 FP8 Quantization.
|
"""Config class for W8A8 FP8 Quantization.
|
||||||
@@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
weight_scale = layer.weight_scale.detach()
|
weight_scale = layer.weight_scale.detach()
|
||||||
if is_hip():
|
if _is_hip:
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=weight, weight_scale=weight_scale
|
weight=weight, weight_scale=weight_scale
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
else:
|
else:
|
||||||
capture_bs = list(range(1, 33))
|
capture_bs = list(range(1, 33))
|
||||||
|
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
capture_bs += [i * 8 for i in range(21, 33)]
|
capture_bs += [i * 8 for i in range(21, 33)]
|
||||||
|
|
||||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||||
from sglang.srt.utils import add_prefix, is_hip
|
from sglang.srt.utils import add_prefix, is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
class DeepseekModelNextN(nn.Module):
|
class DeepseekModelNextN(nn.Module):
|
||||||
@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
weight_block_size = self.quant_config.weight_block_size
|
weight_block_size = self.quant_config.weight_block_size
|
||||||
if weight_block_size is not None:
|
if weight_block_size is not None:
|
||||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
||||||
@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
and self_attn.w_scale is None
|
and self_attn.w_scale is None
|
||||||
):
|
):
|
||||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
self_attn.w_scale *= 2.0
|
self_attn.w_scale *= 2.0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
|
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if is_cuda_available():
|
if is_cuda_available():
|
||||||
from sgl_kernel import bmm_fp8
|
from sgl_kernel import bmm_fp8
|
||||||
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
if no_absorb():
|
if no_absorb():
|
||||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||||
else:
|
else:
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
if (
|
if (
|
||||||
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||||
and forward_batch.forward_mode.is_decode()
|
and forward_batch.forward_mode.is_decode()
|
||||||
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
weight_block_size = self.quant_config.weight_block_size
|
weight_block_size = self.quant_config.weight_block_size
|
||||||
if weight_block_size is not None:
|
if weight_block_size is not None:
|
||||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
||||||
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
and self_attn.w_scale is None
|
and self_attn.w_scale is None
|
||||||
):
|
):
|
||||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
self_attn.w_scale *= 2.0
|
self_attn.w_scale *= 2.0
|
||||||
|
|
||||||
def get_embed_and_head(self):
|
def get_embed_and_head(self):
|
||||||
|
|||||||
@@ -72,13 +72,17 @@ show_time_cost = False
|
|||||||
time_infos = {}
|
time_infos = {}
|
||||||
|
|
||||||
|
|
||||||
|
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
|
||||||
def is_hip() -> bool:
|
def is_hip() -> bool:
|
||||||
"""Return whether it is HIP on the AMD ROCm platform."""
|
|
||||||
return torch.version.hip is not None
|
return torch.version.hip is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_rocm() -> bool:
|
||||||
|
return torch.cuda.is_available() and torch.version.hip
|
||||||
|
|
||||||
|
|
||||||
def is_cuda():
|
def is_cuda():
|
||||||
return hasattr(torch, "cuda") and torch.version.cuda is not None
|
return torch.cuda.is_available() and torch.version.cuda
|
||||||
|
|
||||||
|
|
||||||
def is_cuda_alike():
|
def is_cuda_alike():
|
||||||
@@ -100,11 +104,11 @@ def is_flashinfer_available():
|
|||||||
"""
|
"""
|
||||||
if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
|
if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
|
||||||
return False
|
return False
|
||||||
return torch.cuda.is_available() and torch.version.cuda
|
return is_cuda()
|
||||||
|
|
||||||
|
|
||||||
def is_cuda_available():
|
def is_cuda_available():
|
||||||
return torch.cuda.is_available() and torch.version.cuda
|
return is_cuda()
|
||||||
|
|
||||||
|
|
||||||
def enable_show_time_cost():
|
def enable_show_time_cost():
|
||||||
|
|||||||
Reference in New Issue
Block a user