add deepseekv3 and llama4
This commit is contained in:
@@ -252,19 +252,27 @@ def forward_prefill(
|
|||||||
updated_slot_mapping = attn_metadata.slot_mapping
|
updated_slot_mapping = attn_metadata.slot_mapping
|
||||||
if self.attn.kv_cache_dtype == 'int8':
|
if self.attn.kv_cache_dtype == 'int8':
|
||||||
key_cache_scale = kv_cache[1][0]
|
key_cache_scale = kv_cache[1][0]
|
||||||
mlu_ops.quant_to_paged_cache(key_value,
|
mlu_ops.quant_to_linear_cache(key_value,
|
||||||
|
None,
|
||||||
|
key_cache,
|
||||||
|
None,
|
||||||
|
key_cache_scale,
|
||||||
|
None,
|
||||||
|
attn_metadata.cu_seq_lens,
|
||||||
|
attn_metadata.max_seq_len,
|
||||||
|
True, None,
|
||||||
|
attn_metadata.batch_ids,
|
||||||
|
attn_metadata.slot_mapping_unpaged)
|
||||||
|
else:
|
||||||
|
mlu_ops.reshape_linear_cache(key_value,
|
||||||
None,
|
None,
|
||||||
key_cache,
|
key_cache,
|
||||||
None,
|
None,
|
||||||
key_cache_scale,
|
attn_metadata.cu_seq_lens,
|
||||||
None,
|
attn_metadata.max_seq_len,
|
||||||
updated_slot_mapping.flatten())
|
True, None,
|
||||||
else:
|
attn_metadata.batch_ids,
|
||||||
mlu_ops.reshape_paged_cache(key_value,
|
attn_metadata.slot_mapping_unpaged)
|
||||||
None,
|
|
||||||
key_cache,
|
|
||||||
None,
|
|
||||||
updated_slot_mapping.flatten())
|
|
||||||
'''
|
'''
|
||||||
==================
|
==================
|
||||||
End of MLU Hijack
|
End of MLU Hijack
|
||||||
|
|||||||
Reference in New Issue
Block a user