from typing import Union

import torch
from torch import nn

from vllm.model_executor.layers.rotary_embedding.mrope import (
    MRotaryEmbedding,
    apply_interleaved_rope,
)


def get_sin_cos_mrope(rotary_emb: MRotaryEmbedding,
                      positions: Union[torch.Tensor],
                      use_fuse: bool = True):
    """Look up the multimodal-RoPE cos/sin values for ``positions``.

    Reference: ``MRotaryEmbedding.forward_native``.

    Args:
        rotary_emb: Embedding object providing ``cos_cache``, ``sin_cache``,
            ``mrope_section`` and ``mrope_interleaved``.
        positions: Index tensor used to gather rows from both caches.
            Presumably shaped ``(num_sections, num_tokens)``, since the
            non-interleaved path indexes each section chunk by its section
            number along the leading axis -- TODO confirm against callers.
        use_fuse: When true (the default), delegate the whole computation to
            the fused ``torch.vacc.mrope_get_sin_cos`` operator.

    Returns:
        On the non-fused path, a ``(cos, sin)`` tuple. On the fused path,
        whatever ``torch.vacc.mrope_get_sin_cos`` returns.

        NOTE(review): the function name reads "sin_cos" while the non-fused
        path returns cos first; the fused operator's return order is not
        visible here -- confirm both paths agree.
    """
    # Fused backend path (original note: "odsp").
    if use_fuse:
        return torch.vacc.mrope_get_sin_cos(
            rotary_emb.sin_cache,
            rotary_emb.cos_cache,
            positions,
            rotary_emb.mrope_section,
            rotary_emb.mrope_interleaved,
        )

    # Native path: gather the cached rows for the requested positions.
    cos_rows = rotary_emb.cos_cache[positions]
    sin_rows = rotary_emb.sin_cache[positions]

    if rotary_emb.mrope_interleaved:
        section = rotary_emb.mrope_section
        return (
            apply_interleaved_rope(cos_rows, section),
            apply_interleaved_rope(sin_rows, section),
        )

    def _merge_sections(rows):
        # Split the last dim into the configured sections, take the k-th
        # leading-axis slice from the k-th chunk, and stitch them back
        # together along the last dim.
        chunks = rows.split(rotary_emb.mrope_section, dim=-1)
        picked = [chunk[k] for k, chunk in enumerate(chunks)]
        return torch.cat(picked, dim=-1)

    return _merge_sections(cos_rows), _merge_sections(sin_rows)