Fix typos and unify size(s)/stride(s) API calls (#8799)
This commit is contained in:
@@ -105,10 +105,10 @@ typename T::Fmha::Arguments args_from_options(
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int batches = q_nope.size(0);
|
||||
int page_count_per_seq = page_table.size(1);
|
||||
int page_count_total = kv_c_and_k_pe_cache.size(0);
|
||||
int page_size = kv_c_and_k_pe_cache.size(1);
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
@@ -220,7 +220,7 @@ void cutlass_mla_decode(
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
const int page_size = kv_c_and_k_pe_cache.size(1);
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
||||
|
||||
Reference in New Issue
Block a user