diff --git a/python/pyproject.toml b/python/pyproject.toml index 973307bdc..9ef33e2f7 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -82,7 +82,7 @@ srt_hip = [ "sglang[runtime_common]", "torch", "petit_kernel==0.0.2", - "wave-lang==1.0.1", + "wave-lang==3.7.0", ] # https://docs.sglang.ai/platforms/cpu_server.html diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py index cb89697bd..c76bee9af 100644 --- a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -64,8 +64,7 @@ def get_wave_kernel( subs=hyperparams_0, canonicalize=True, run_bench=False, - use_buffer_load_ops=True, - use_buffer_store_ops=True, + use_buffer_ops=True, waves_per_eu=2, dynamic_symbols=dynamic_symbols_0, wave_runtime=True, @@ -77,8 +76,7 @@ def get_wave_kernel( subs=hyperparams_1, canonicalize=True, run_bench=False, - use_buffer_load_ops=False, - use_buffer_store_ops=False, + use_buffer_ops=False, waves_per_eu=4, dynamic_symbols=dynamic_symbols_1, wave_runtime=True, diff --git a/python/sglang/srt/layers/attention/wave_ops/extend_attention.py b/python/sglang/srt/layers/attention/wave_ops/extend_attention.py index 35a53d3e2..27e674db2 100644 --- a/python/sglang/srt/layers/attention/wave_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/wave_ops/extend_attention.py @@ -67,11 +67,9 @@ def get_wave_kernel( schedule=SchedulingType.NONE, use_scheduling_barriers=False, dynamic_symbols=dynamic_symbols, - use_buffer_load_ops=True, - use_buffer_store_ops=True, + use_buffer_ops=True, waves_per_eu=2, denorm_fp_math_f32="preserve-sign", - gpu_native_math_precision=True, wave_runtime=True, ) options = set_default_run_config(options)