[chore] Remove obsolete comments
This commit is contained in:
@@ -73,7 +73,6 @@ class KunlunOps:
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV1"""
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
@@ -116,7 +115,6 @@ class KunlunOps:
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV2"""
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
@@ -221,17 +219,6 @@ class KunlunOps:
|
||||
num_heads = query_x.shape[1] // head_size
|
||||
num_kv_heads = key_x.shape[1] // head_size
|
||||
|
||||
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
|
||||
# query_x = query_x.view(num_tokens, num_heads, head_size)
|
||||
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
|
||||
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
|
||||
|
||||
# # Ensure shapes are correct
|
||||
# assert query_x.shape == (num_tokens, num_heads, head_size), \
|
||||
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
|
||||
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
|
||||
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
|
||||
)
|
||||
@@ -239,8 +226,6 @@ class KunlunOps:
|
||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||
|
||||
# query.data = query_x
|
||||
# key.data = key_x
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@@ -290,7 +275,6 @@ class KunlunOps:
|
||||
kv_cache_dtype,
|
||||
):
|
||||
"""reshape_and_cache"""
|
||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||
xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -271,13 +271,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
if self.context_lens_tensor is None
|
||||
else self.context_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
# for prefix cache, block table only contains blocks that hit
|
||||
# if self.block_tables is None:
|
||||
# block_tables = None
|
||||
# elif self.block_tables.shape[1] == 0:
|
||||
# block_tables = self.block_tables[:self.num_prefills]
|
||||
# else:
|
||||
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
|
||||
|
||||
block_tables = (
|
||||
None
|
||||
@@ -442,7 +435,6 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
if inter_data.prefix_cache_hit:
|
||||
assert context_len != 0
|
||||
assert context_len % self.block_size == 0
|
||||
# block_table = block_tables[seq_id]
|
||||
block_table = block_tables[seq_id][: context_len // self.block_size]
|
||||
elif (not is_prompt) and block_tables is not None:
|
||||
if curr_sliding_window_block == 0:
|
||||
@@ -483,7 +475,6 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
query_start_loc, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
attn_meta.query_start_loc_host = query_start_loc_host
|
||||
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
|
||||
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
|
||||
|
||||
# If kv cache is included and there is a hit
|
||||
|
||||
@@ -516,10 +516,6 @@ def _apply_top_k_top_p(
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
# Re-sort the probabilities.
|
||||
# logits = torch.empty_like(logits_sort).scatter_(dim=-1,
|
||||
# index=logits_idx,
|
||||
# src=logits_sort)
|
||||
return logits_sort, logits_idx
|
||||
|
||||
|
||||
@@ -883,7 +879,6 @@ def _sample_with_torch(
|
||||
seq_groups=seq_groups_arg,
|
||||
)
|
||||
if logits_idx is not None:
|
||||
# multinomial_samples[sampling_type] = logits_idx[:, result_idx[:][0]]
|
||||
token_ids = logits_idx[long_sample_indices].gather(
|
||||
dim=1, index=result_idx.to(logits_idx.device)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user