Optimize rope in sgl kernel (#4267)

This commit is contained in:
Lianmin Zheng
2025-03-10 10:07:45 -07:00
committed by GitHub
parent 3d56585a97
commit cf0ccd406e
2 changed files with 2 additions and 3 deletions

View File

@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache(
static_cast<c_type*>(q_rope.data_ptr()),
static_cast<c_type*>(k_rope.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()),
static_cast<int64_t*>(pos_ids.data_ptr()),
nnz,
num_qo_heads,
num_kv_heads,