diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml
index 03406ef86..a4aa39ad2 100644
--- a/.github/workflows/pr-test-amd.yml
+++ b/.github/workflows/pr-test-amd.yml
@@ -55,6 +55,7 @@ jobs:
         timeout-minutes: 20
         run: |
           docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_accuracy_large.py
+          docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py

   mla-test-1-gpu-amd:
     if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 44a3cba8a..79092813a 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -624,56 +624,9 @@ class Fp8MoEMethod:

     def process_weights_after_loading(self, layer: Module) -> None:
         if get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
-            # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
-            # Weight Permutation
-            layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-
-            # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
-            # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
-            # We won't do requant each expert's fp8 weight (not direct available),
-            # instead we adjust half of INT4 w13_weight_scale1 numbers
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
-                start = 0
-                max_w13_scale_fp8 = max_w13_scales[expert_id]
-                for shard_id in range(2):
-                    if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
-                        int4_rescale = (
-                            layer.w13_weight_scale[expert_id][shard_id]
-                            / max_w13_scale_fp8
-                        )
-                        layer.w13_weight_scale1[expert_id][
-                            start : start + shard_size
-                        ] *= int4_rescale
-                    start += shard_size
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                max_w13_scales, requires_grad=False
-            )
-
-            # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
-            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.num_experts):
-                layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
-                layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+            self.process_weights_hip_int4(layer)
             return

-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-            padding_size,  # Avoid circular import
-        )
-
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
@@ -710,6 +663,7 @@ class Fp8MoEMethod:
                     layer.w2_weight.contiguous(), (16, 16)
                 )
             return

+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
@@ -736,32 +690,7 @@ class Fp8MoEMethod:
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

             if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+                self.process_weights_hip_scale_padding(layer)
             return

         # If checkpoint is fp8, we need to handle that the
@@ -843,34 +772,84 @@ class Fp8MoEMethod:
             )

         if is_hip_:
-            if get_bool_env_var("CK_MOE"):
-                layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                # ROCm (CK_MOE): using column-wise scaling
-                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-            elif get_bool_env_var("MOE_PADDING"):
-                # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            self.process_weights_hip_scale_padding(layer)
         return

+    def process_weights_hip_int4(self, layer: Module):
+        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+        # Weight Permutation
+        layer.w13_weight = torch.nn.Parameter(
+            permute_weight(layer.w13_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        layer.w2_weight = torch.nn.Parameter(
+            permute_weight(layer.w2_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+
+        # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+        # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+        # We won't do requant each expert's fp8 weight (not direct available),
+        # instead we adjust half of INT4 w13_weight_scale1 numbers
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale1[expert_id][
+                        start : start + shard_size
+                    ] *= int4_rescale
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+        # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+        for expert_id in range(layer.num_experts):
+            layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def process_weights_hip_scale_padding(self, layer: Module):
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
+        if get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            # ROCm (CK_MOE): using column-wise scaling
+            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
+        elif get_bool_env_var("MOE_PADDING"):
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            layer.w13_weight = torch.nn.Parameter(
+                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index feaae26f6..e53b971be 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -1,8 +1,6 @@
-import os
 from typing import List, Optional, Tuple

 import torch
-from packaging.version import Version

 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -13,18 +11,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
+    is_cuda,
     is_hip,
 )

-use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
-    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
-)
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

 is_hip_ = is_hip()
 if is_hip_ and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm

@@ -73,7 +70,7 @@ def normalize_e4m3fn_to_e4m3fnuz(


 def cutlass_block_fp8_supported() -> bool:
-    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index 08b487262..ebc148feb 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -264,7 +264,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_method = None
         if quant_config is not None:
             quant_method = quant_config.get_quant_method(self, prefix=prefix)
-            print("quant_method", quant_method)
         if quant_method is None:
             quant_method = UnquantizedEmbeddingMethod()

diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py
index 20bf4c689..d2a418ccf 100644
--- a/test/srt/models/test_qwen_models.py
+++ b/test/srt/models/test_qwen_models.py
@@ -69,7 +69,7 @@ class TestQwen2FP8(unittest.TestCase):
         )
         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreater(metrics["accuracy"], 0.79)
+        self.assertGreater(metrics["accuracy"], 0.78)


 if __name__ == "__main__":