Revert "Fix nightly-test CI" (#4065)
This commit is contained in:
@@ -427,10 +427,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
else:
|
||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
self._to_dtype(
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
q.dtype,
|
||||
),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=False,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
@@ -472,9 +469,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
self._to_dtype(
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), q.dtype
|
||||
),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
k_scale=layer.k_scale,
|
||||
@@ -483,12 +478,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def _to_dtype(self, kv_tuple, dtype):
|
||||
if kv_tuple[0].dtype != dtype:
|
||||
return tuple(t.to(dtype) for t in kv_tuple)
|
||||
else:
|
||||
return kv_tuple
|
||||
|
||||
def _get_wrapper_idx(self, layer: RadixAttention):
|
||||
if self.num_wrappers == 1:
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user