Fix cuda graph mode in flashinfer attn backend (#10056)
This commit is contained in:
@@ -501,8 +501,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
window_left=layer.sliding_window_size,
|
window_left=layer.sliding_window_size,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
k_scale=layer.k_scale,
|
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
|
||||||
v_scale=layer.v_scale,
|
k_scale=layer.k_scale_float,
|
||||||
|
v_scale=layer.v_scale_float,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
causal = True
|
causal = True
|
||||||
@@ -580,8 +581,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=layer.logit_cap,
|
||||||
k_scale=layer.k_scale,
|
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
|
||||||
v_scale=layer.v_scale,
|
k_scale=layer.k_scale_float,
|
||||||
|
v_scale=layer.v_scale_float,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|||||||
Reference in New Issue
Block a user