Fix typos and unify size(s)/stride(s) API calls (#8799)

This commit is contained in:
triple-mu
2025-08-08 15:18:08 +08:00
committed by GitHub
parent 9c7e392465
commit 444013585d
6 changed files with 34 additions and 34 deletions

View File

@@ -105,10 +105,10 @@ typename T::Fmha::Arguments args_from_options(
hw_info.device_id = q_nope.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
int batches = q_nope.sizes()[0];
int page_count_per_seq = page_table.sizes()[1];
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
int page_size = kv_c_and_k_pe_cache.sizes()[1];
int batches = q_nope.size(0);
int page_count_per_seq = page_table.size(1);
int page_count_total = kv_c_and_k_pe_cache.size(0);
int page_size = kv_c_and_k_pe_cache.size(1);
int max_seq_len = page_size * page_count_per_seq;
using TileShapeH = typename T::TileShapeH;
using TileShapeD = typename T::TileShapeD;
@@ -220,7 +220,7 @@ void cutlass_mla_decode(
auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
const int page_size = kv_c_and_k_pe_cache.size(1);
// NOTE(alcanderian): IsPersistent has a bug with manual split_kv.
// The kernel will hang if the batch is too large with a large num_kv_splits (for example, bs=8, num_kv_splits=8).