Optimize rope in sgl kernel (#4267)
This commit is contained in:
@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
positions = positions.int()
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
|
||||
q=query.view(query.shape[0], -1, head_size),
|
||||
k=key.view(key.shape[0], -1, head_size),
|
||||
q_rope=query.view(query.shape[0], -1, head_size),
|
||||
k_rope=key.view(key.shape[0], -1, head_size),
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
pos_ids=positions,
|
||||
pos_ids=positions.long(),
|
||||
interleave=(not is_neox),
|
||||
cuda_stream=get_cuda_stream(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user