Files
enginex-vastai-va16-vllm/torch_vacc/vacc/custom_qwen3_ops.py
2026-04-02 04:55:00 +00:00

146 lines
4.7 KiB
Python

from typing import List, Optional, Tuple, Union
import torch
from torch import Generator
from torch_vacc._vacc_libs import _torch_vacc
def fuse_moe_prefill_stage0_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    gate_weight,
    rms_hidden_state_opt: Optional[torch.Tensor] = None,
    zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
    topk_ids_opt: Optional[torch.Tensor] = None,
    topk_weight_opt: Optional[torch.Tensor] = None,
):
    """Invoke the native stage-0 fused MoE prefill op for Qwen models.

    Thin pass-through. Every argument is forwarded positionally, in
    declaration order, to ``_torch_vacc.fuse_moe_prefill_stage0_qwen``, and
    that call's result is returned unchanged. No validation or conversion is
    done here.

    The four ``*_opt`` tensors default to ``None``. They are presumably
    optional preallocated output buffers for the native op (TODO confirm
    against the C++ binding).
    """
    required = (hidden_states, rms_residual, rms_weight, gate_weight)
    optional_buffers = (
        rms_hidden_state_opt,
        zero_moe_hidden_state_opt,
        topk_ids_opt,
        topk_weight_opt,
    )
    native_op = _torch_vacc.fuse_moe_prefill_stage0_qwen
    return native_op(*required, *optional_buffers)
def fuse_moe_decode_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    moe_weight_13,
    moe_weight_2,
    moe_weight_13_dequat,
    moe_weight_2_dequant,
    gate_weight,
    block_size_13,
    block_size_2,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: Optional[List[int]] = None,
    output: Optional[torch.Tensor] = None,
):
    """Invoke the native fused MoE decode op for Qwen models.

    All arguments are forwarded positionally, in declaration order, to
    ``_torch_vacc.fuse_moe_decode_qwen``, and that call's result is returned
    unchanged.

    Args:
        world_size: Number of participating ranks. When ``dev_info`` is
            missing, it is also used to build a default ``dev_info``.
        rank: Forwarded unchanged to the native op.
        group_id: Forwarded unchanged to the native op.
        dev_info: Per-rank device descriptors. When ``None`` or empty, a
            default list of ``world_size`` entries is built, where entry ``i``
            is ``i | (i << 16)``: the index repeated in the low and high
            16 bits. NOTE(review): the meaning of the two halves is
            defined by the native op; confirm before relying on it.
        output: Optional tensor forwarded to the native op (presumably a
            preallocated result buffer; TODO confirm).

    The remaining tensor/size arguments (note the parameter name
    ``moe_weight_13_dequat`` is spelled as-is for caller compatibility) are
    passed through untouched.
    """
    # The declared default is None. The previous ``0 == len(dev_info)`` check
    # raised TypeError in that case, making the default unusable. Treat None
    # like an empty list, but keep len() so sized non-list inputs still work.
    if dev_info is None or len(dev_info) == 0:
        dev_info = [i | (i << 16) for i in range(world_size)]
    return _torch_vacc.fuse_moe_decode_qwen(
        hidden_states,
        rms_residual,
        rms_weight,
        moe_weight_13,
        moe_weight_2,
        moe_weight_13_dequat,
        moe_weight_2_dequant,
        gate_weight,
        block_size_13,
        block_size_2,
        world_size,
        rank,
        group_id,
        dev_info,
        output,
    )
def rot_pos_emb_qwenvl(grid_thw: List[List[int]],
                       hidden_size: int,
                       head_num: int,
                       spatial_merge_size: int,
                       dtype: torch.dtype,
                       device: Union[int, str, torch.device] = "vacc"):
    """Compute rotary position embeddings via the native Qwen-VL op.

    ``device`` is normalized before the call:
      * a ``str`` becomes ``torch.device(device)``;
      * an ``int`` becomes ``torch.device("vacc", device)``;
      * anything else (e.g. an existing ``torch.device``) is passed as-is.

    ``grid_thw`` is flattened row by row into one flat list, because the
    native op takes a single flat sequence. Each row is presumably a
    ``(t, h, w)`` grid triple (TODO confirm against callers).
    """
    if isinstance(device, str):
        target = torch.device(device)
    elif isinstance(device, int):
        target = torch.device("vacc", device)
    else:
        target = device
    flat_thw = [dim for thw in grid_thw for dim in thw]
    return _torch_vacc.rot_pos_emb_qwenvl(
        flat_thw,
        hidden_size,
        head_num,
        spatial_merge_size,
        dtype,
        target,
    )
def fast_pos_embed_interpolate_qwenvl(weight: torch.Tensor,
                                      grid_thw: List[List[int]],
                                      num_grid_per_side: int,
                                      spatial_merge_size: int,
                                      hidden_dim: int):
    """Interpolate position embeddings via the native Qwen-VL op.

    ``grid_thw`` is flattened row by row into one flat list before the call.
    All other arguments are forwarded unchanged to
    ``_torch_vacc.fast_pos_embed_interpolate_qwenvl``, and its result is
    returned as-is.
    """
    flat_thw = [value for row in grid_thw for value in row]
    return _torch_vacc.fast_pos_embed_interpolate_qwenvl(
        weight,
        flat_thw,
        num_grid_per_side,
        spatial_merge_size,
        hidden_dim,
    )
# Qwen2-VL and Qwen3-VL share the same image preprocessing op.
def qwen2vl_img_preprocess(
    image: "torch.Tensor",
    do_resize: bool,
    min_pixels: int,
    max_pixels: int,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    resized_height: int,
    resized_width: int,
    interpolation: int,  # integer code; originally noted as Optional["F.InterpolationMode"]
    patch_size: int,
    temporal_patch_size: int,
    merge_size: int,
    image_mean0: float,
    image_mean1: float,
    image_mean2: float,
    image_std0: float,
    image_std1: float,
    image_std2: float,
):
    """Run the native Qwen-VL image preprocessing op on a vacc tensor.

    ``image`` must already live on a ``"vacc"`` device. This is checked with
    ``assert``, so the check is skipped when Python runs with ``-O``. All
    arguments, including the three per-channel mean and std values, are
    forwarded positionally to ``_torch_vacc.qwen2vl_img_preprocess``, and its
    result is returned unchanged.
    """
    assert image.device.type == "vacc", f"please target vacc device, now is {image.device}"
    channel_mean = (image_mean0, image_mean1, image_mean2)
    channel_std = (image_std0, image_std1, image_std2)
    return _torch_vacc.qwen2vl_img_preprocess(
        image,
        do_resize,
        min_pixels,
        max_pixels,
        do_rescale,
        rescale_factor,
        do_normalize,
        resized_height,
        resized_width,
        interpolation,
        patch_size,
        temporal_patch_size,
        merge_size,
        *channel_mean,
        *channel_std,
    )