fix page first per layer pf2lf kernel (#8915)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -34,6 +34,7 @@ def transfer_kv_per_layer_pf_lf(
|
||||
dst_v: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
layer_id: int,
|
||||
item_size: int,
|
||||
src_layout_dim: int,
|
||||
block_quota: int = 2,
|
||||
@@ -46,6 +47,7 @@ def transfer_kv_per_layer_pf_lf(
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
layer_id,
|
||||
item_size,
|
||||
src_layout_dim,
|
||||
block_quota,
|
||||
@@ -144,6 +146,7 @@ def transfer_kv_per_layer_mla_pf_lf(
|
||||
dst: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
layer_id: int,
|
||||
item_size: int,
|
||||
src_layout_dim: int,
|
||||
block_quota: int = 2,
|
||||
@@ -154,6 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf(
|
||||
dst,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
layer_id,
|
||||
item_size,
|
||||
src_layout_dim,
|
||||
block_quota,
|
||||
|
||||
Reference in New Issue
Block a user