Minor update for ROCm variable style (#5562)
This commit is contained in:
@@ -20,7 +20,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkConfig(TypedDict):
|
class BenchmarkConfig(TypedDict):
|
||||||
@@ -112,8 +112,8 @@ def benchmark_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
|
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
|
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||||
|
|
||||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
@@ -204,7 +204,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
|
|||||||
# TODO(woosuk): Increase the search space and use a performance model to
|
# TODO(woosuk): Increase the search space and use a performance model to
|
||||||
# prune the search space.
|
# prune the search space.
|
||||||
configs: List[BenchmarkConfig] = []
|
configs: List[BenchmarkConfig] = []
|
||||||
if _is_hip_:
|
if _is_hip:
|
||||||
configs = get_rocm_configs_compute_bound()
|
configs = get_rocm_configs_compute_bound()
|
||||||
else:
|
else:
|
||||||
for num_stages in [2, 3, 4, 5]:
|
for num_stages in [2, 3, 4, 5]:
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
|
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
|
||||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
DTYPE_MAP = {
|
DTYPE_MAP = {
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
@@ -99,7 +99,7 @@ def w8a8_block_matmul(
|
|||||||
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
|
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
|
||||||
kernel = (
|
kernel = (
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
if (is_hip_ == True and num_workgroups <= get_device_core_count())
|
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
||||||
else _w8a8_block_fp8_matmul
|
else _w8a8_block_fp8_matmul
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -157,7 +157,7 @@ def get_rocm_configs_compute_bound():
|
|||||||
|
|
||||||
def get_configs_compute_bound():
|
def get_configs_compute_bound():
|
||||||
configs = []
|
configs = []
|
||||||
if is_hip_:
|
if _is_hip:
|
||||||
configs = get_rocm_configs_compute_bound()
|
configs = get_rocm_configs_compute_bound()
|
||||||
else:
|
else:
|
||||||
for num_stages in [2, 3, 4, 5]:
|
for num_stages in [2, 3, 4, 5]:
|
||||||
@@ -244,7 +244,7 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
|||||||
|
|
||||||
if input_type == "fp8":
|
if input_type == "fp8":
|
||||||
fp8_info = torch.finfo(
|
fp8_info = torch.finfo(
|
||||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
@@ -252,14 +252,14 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
|||||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||||
)
|
)
|
||||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
|
|
||||||
B_fp32 = (
|
B_fp32 = (
|
||||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||||
)
|
)
|
||||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
int8_info = torch.iinfo(torch.int8)
|
int8_info = torch.iinfo(torch.int8)
|
||||||
|
|||||||
Reference in New Issue
Block a user