Optimize rope in sgl kernel (#4267)
@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     static_cast<c_type*>(q_rope.data_ptr()),
     static_cast<c_type*>(k_rope.data_ptr()),
     static_cast<float*>(cos_sin_cache.data_ptr()),
-    static_cast<int32_t*>(pos_ids.data_ptr()),
+    static_cast<int64_t*>(pos_ids.data_ptr()),
     nnz,
     num_qo_heads,
     num_kv_heads,
@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
 
-    positions = positions.int()
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
         k_rope=key.view(key.shape[0], -1, head_size),
         cos_sin_cache=cos_sin_cache,
-        pos_ids=positions,
+        pos_ids=positions.long(),
         interleave=(not is_neox),
         cuda_stream=get_cuda_stream(),
     )
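For context, the user-visible effect of this change is that position ids now flow through as int64 end to end: the Python wrapper passes positions.long() and the CUDA side reads them via static_cast<int64_t*>, rather than downcasting to int32 first. Below is a minimal usage sketch, assuming sgl_kernel exposes apply_rope_with_cos_sin_cache_inplace under that name and accepts the keyword arguments visible in the diff; the shapes, dtypes, and cos/sin cache construction are illustrative assumptions, not taken from the commit.

import torch
# Assumption: the wrapper from the diff is importable like this.
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

num_tokens, num_qo_heads, num_kv_heads, head_size = 8, 4, 2, 64
rotary_dim = head_size  # assumption: the full head is rotated

query = torch.randn(num_tokens, num_qo_heads * head_size,
                    device="cuda", dtype=torch.float16)
key = torch.randn(num_tokens, num_kv_heads * head_size,
                  device="cuda", dtype=torch.float16)

# cos_sin_cache must be float32 (the wrapper raises ValueError otherwise).
# Layout assumption: first half cos, second half sin, per position.
max_pos = 1024
inv_freq = 1.0 / (10000.0 ** (
    torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).cuda()

# After this commit there is no int32 round trip: torch.arange already
# yields int64, which the kernel now consumes directly.
positions = torch.arange(num_tokens, device="cuda")

apply_rope_with_cos_sin_cache_inplace(
    positions=positions,
    query=query,
    key=key,
    head_size=head_size,
    cos_sin_cache=cos_sin_cache,
    is_neox=True,
)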