42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
|
|
from typing import Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
|
|
|
|
|
|
def get_sin_cos_mrope(rotary_emb: "MRotaryEmbedding", positions: torch.Tensor, use_fuse: bool = True):
    """Gather the rotary cos/sin tables for ``positions`` from an M-RoPE module.

    Reference: ``MRotaryEmbedding.forward_native``.

    Args:
        rotary_emb: rotary embedding module. This function reads its
            ``cos_cache``, ``sin_cache``, ``mrope_section`` and
            ``mrope_interleaved`` attributes.
        positions: position ids used to index the first dimension of
            ``cos_cache`` / ``sin_cache``. 1-D positions are gathered as-is.
            2-D positions are treated as one row per M-RoPE section
            (presumably shape ``(len(mrope_section), num_tokens)`` -- TODO
            confirm against callers), and the per-section columns are merged
            into a single table.
        use_fuse: when True, delegate to the fused
            ``torch.vacc.mrope_get_sin_cos`` kernel instead of the fallback.

    Returns:
        On the fallback path, a ``(cos, sin)`` tuple of tensors. On the fused
        path, whatever ``torch.vacc.mrope_get_sin_cos`` returns.
        NOTE(review): the fused kernel receives ``sin_cache`` before
        ``cos_cache``; confirm it still returns ``(cos, sin)`` in the same
        order as the fallback path.
    """
    # odsp: fused torch.vacc kernel path.
    if use_fuse:
        return torch.vacc.mrope_get_sin_cos(
            rotary_emb.sin_cache,
            rotary_emb.cos_cache,
            positions,
            rotary_emb.mrope_section,
            rotary_emb.mrope_interleaved,
        )

    cos = rotary_emb.cos_cache[positions]
    sin = rotary_emb.sin_cache[positions]

    # Only 2-D positions carry one row of ids per section. 1-D positions use
    # the gathered tables unchanged, matching forward_native. Without this
    # guard, the per-section merge would index tokens instead of sections.
    if positions.ndim == 2:
        section = rotary_emb.mrope_section
        interleaved = rotary_emb.mrope_interleaved
        cos = _select_mrope_sections(cos, section, interleaved)
        sin = _select_mrope_sections(sin, section, interleaved)

    return cos, sin


def _select_mrope_sections(x: torch.Tensor, mrope_section, interleaved: bool) -> torch.Tensor:
    """Merge per-section rows of a gathered cos/sin table into one table.

    Args:
        x: table gathered with 2-D positions. Its leading dimension selects
            the section row.
        mrope_section: sizes of the sections along the last dimension.
        interleaved: if True, use ``apply_interleaved_rope``. Otherwise,
            concatenate contiguous chunks.

    Returns:
        A tensor without the leading section dimension.
    """
    if interleaved:
        return apply_interleaved_rope(x, mrope_section)
    # Chunk i along the last dim comes from section row i.
    chunks = x.split(mrope_section, dim=-1)
    return torch.cat([chunk[i] for i, chunk in enumerate(chunks)], dim=-1)
|
|
|