Files
2026-04-02 04:55:00 +00:00

42 lines
1.2 KiB
Python

from typing import Union
import torch
from torch import nn
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
def _select_mrope_sections(x: torch.Tensor, mrope_section, interleaved: bool) -> torch.Tensor:
    """Merge the per-section rows of a gathered cos/sin table into one table.

    ``x`` is ``cos_cache[positions]`` or ``sin_cache[positions]`` for 2-D
    ``positions``, so ``x[i]`` is the table looked up with the i-th row of
    positions.
    """
    if interleaved:
        return apply_interleaved_rope(x, mrope_section)
    # Contiguous layout: the i-th chunk of the last dim is taken from row i.
    return torch.cat(
        [chunk[i] for i, chunk in enumerate(x.split(mrope_section, dim=-1))],
        dim=-1,
    )


def get_sin_cos_mrope(
    rotary_emb: "MRotaryEmbedding",
    positions: torch.Tensor,
    use_fuse: bool = True,
):
    """Return the multimodal-RoPE cos/sin tables for ``positions``.

    Follows the table lookup of vLLM's ``MRotaryEmbedding.forward_native``.

    Args:
        rotary_emb: embedding providing ``cos_cache``, ``sin_cache``,
            ``mrope_section`` and ``mrope_interleaved``.
        positions: either 1-D token positions, or 2-D positions whose leading
            axis has one row per entry of ``mrope_section``. 1-D positions skip
            the per-section merge, as in ``forward_native``.
        use_fuse: when True, delegate to ``torch.vacc.mrope_get_sin_cos`` and
            return its result unchanged.

    Returns:
        ``(cos, sin)`` on the unfused path.
        NOTE(review): the fused op receives sin before cos. Confirm it also
        returns ``(cos, sin)``.
    """
    # odsp: fused device path (tag kept from the original; meaning not evident here)
    if use_fuse:
        return torch.vacc.mrope_get_sin_cos(
            rotary_emb.sin_cache,
            rotary_emb.cos_cache,
            positions,
            rotary_emb.mrope_section,
            rotary_emb.mrope_interleaved,
        )

    cos = rotary_emb.cos_cache[positions]
    sin = rotary_emb.sin_cache[positions]

    # Text-only (1-D) positions have no per-section rows to merge. Indexing
    # chunk[i] here would pick token rows instead, which is wrong.
    if positions.ndim == 1:
        return cos, sin

    section = rotary_emb.mrope_section
    interleaved = rotary_emb.mrope_interleaved
    return (
        _select_mrope_sections(cos, section, interleaved),
        _select_mrope_sections(sin, section, interleaved),
    )