diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 6e3418808..a5b207c77 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -501,8 +501,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid a device-to-host copy that breaks CUDA graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -580,8 +581,9 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            # Must use _float to avoid a device-to-host copy that breaks CUDA graph capture.
+            k_scale=layer.k_scale_float,
+            v_scale=layer.v_scale_float,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
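
Note on why the fix works (an illustrative sketch, not part of the patch): the wrapper calls presumably consume `k_scale`/`v_scale` as host-side scalars, so passing the tensor-valued `layer.k_scale` forces an implicit device-to-host read (e.g. `.item()`) at launch time, and CUDA graph capture rejects such syncs. The sketch below shows the pattern the diff relies on; the `AttnLayerSketch` class and its constructor are assumptions for illustration, while the `k_scale_float`/`v_scale_float` names come from the diff itself.

```python
# Hypothetical sketch (not SGLang's actual code): k_scale / v_scale are 0-dim
# GPU tensors; reading them on the host forces a device-to-host sync, which
# is not allowed while a CUDA graph is being captured. Caching float copies
# once at load time keeps later kernel launches free of host syncs.
from typing import Optional

import torch


class AttnLayerSketch:
    """Hypothetical stand-in for the attention layer touched by this diff."""

    def __init__(
        self,
        k_scale: Optional[torch.Tensor],
        v_scale: Optional[torch.Tensor],
    ):
        self.k_scale = k_scale
        self.v_scale = v_scale
        # Pay the device-to-host copy exactly once, outside graph capture.
        self.k_scale_float = k_scale.item() if k_scale is not None else 1.0
        self.v_scale_float = v_scale.item() if v_scale is not None else 1.0


layer = AttnLayerSketch(
    k_scale=torch.tensor(0.5),  # would live on the GPU in practice
    v_scale=torch.tensor(0.25),
)
# Safe to read inside torch.cuda.graph(...) capture: plain host floats.
print(layer.k_scale_float, layer.v_scale_float)
```

Computing the floats eagerly at load time means each prefill/decode launch only reads Python floats, so nothing in the captured region ever stalls on a GPU read.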