init
This commit is contained in:
41
vllm_vacc/vllm/model_executor/ops/mrope_op.py
Normal file
41
vllm_vacc/vllm/model_executor/ops/mrope_op.py
Normal file
@@ -0,0 +1,41 @@
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
|
||||
|
||||
|
||||
def _merge_mrope_sections(gathered, mrope_section, interleaved):
    """Collapse the leading position-axis dimension of a gathered cos/sin tensor.

    ``gathered`` is the cos or sin cache indexed with 2-D positions, so its
    first dimension runs over the position rows.
    """
    if interleaved:
        return apply_interleaved_rope(gathered, mrope_section)
    # Split the last dimension into the configured sections and take
    # section ``axis`` from position row ``axis``.
    chunks = gathered.split(mrope_section, dim=-1)
    return torch.cat([chunk[axis] for axis, chunk in enumerate(chunks)], dim=-1)


def get_sin_cos_mrope(rotary_emb: "MRotaryEmbedding", positions: torch.Tensor, use_fuse: bool = True):
    """Look up the cos/sin values for ``positions`` from an M-RoPE embedding.

    Mirrors the cos/sin preparation of ``MRotaryEmbedding.forward_native``.

    Args:
        rotary_emb: Object exposing ``cos_cache``, ``sin_cache``,
            ``mrope_section`` and ``mrope_interleaved``.
        positions: Index tensor used to gather rows from the caches. When it
            is 2-D, its first dimension selects the per-section position row.
        use_fuse: When true (the default), delegate to the fused
            ``torch.vacc.mrope_get_sin_cos`` op and return its result as-is.

    Returns:
        On the non-fused path, a ``(cos, sin)`` tuple.
        NOTE(review): the fused op receives ``sin_cache`` before
        ``cos_cache``; confirm that it also returns ``(cos, sin)``.
    """
    # ref: MRotaryEmbedding forward_native
    # odsp -- NOTE(review): presumably names the vendor fused kernel; confirm.
    if use_fuse:
        return torch.vacc.mrope_get_sin_cos(
            rotary_emb.sin_cache,
            rotary_emb.cos_cache,
            positions,
            rotary_emb.mrope_section,
            rotary_emb.mrope_interleaved,
        )

    cos = rotary_emb.cos_cache[positions]
    sin = rotary_emb.sin_cache[positions]

    # Only 2-D positions carry a separate row per section. For 1-D positions
    # the gathered values are already final, and merging them would index
    # the token dimension by mistake.
    if positions.ndim != 2:
        return cos, sin

    section = rotary_emb.mrope_section
    interleaved = rotary_emb.mrope_interleaved
    cos_cache = _merge_mrope_sections(cos, section, interleaved)
    sin_cache = _merge_mrope_sections(sin, section, interleaved)
    return cos_cache, sin_cache
|
||||
|
||||
Reference in New Issue
Block a user