init
This commit is contained in:
146
torch_vacc/vacc/custom_qwen3_ops.py
Normal file
146
torch_vacc/vacc/custom_qwen3_ops.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Generator
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
|
||||
|
||||
def fuse_moe_prefill_stage0_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
gate_weight,
|
||||
rms_hidden_state_opt: Optional[torch.Tensor] = None,
|
||||
zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
|
||||
topk_ids_opt: Optional[torch.Tensor] = None,
|
||||
topk_weight_opt: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return _torch_vacc.fuse_moe_prefill_stage0_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
gate_weight,
|
||||
rms_hidden_state_opt,
|
||||
zero_moe_hidden_state_opt,
|
||||
topk_ids_opt,
|
||||
topk_weight_opt,
|
||||
)
|
||||
|
||||
|
||||
def fuse_moe_decode_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
moe_weight_13,
|
||||
moe_weight_2,
|
||||
moe_weight_13_dequat,
|
||||
moe_weight_2_dequant,
|
||||
gate_weight,
|
||||
block_size_13,
|
||||
block_size_2,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
group_id: int,
|
||||
dev_info: List[int] = None,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if 0 == len(dev_info):
|
||||
dev_info = [i | (i << 16) for i in range(world_size)]
|
||||
return _torch_vacc.fuse_moe_decode_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
moe_weight_13,
|
||||
moe_weight_2,
|
||||
moe_weight_13_dequat,
|
||||
moe_weight_2_dequant,
|
||||
gate_weight,
|
||||
block_size_13,
|
||||
block_size_2,
|
||||
world_size,
|
||||
rank,
|
||||
group_id,
|
||||
dev_info,
|
||||
output,
|
||||
)
|
||||
|
||||
|
||||
def rot_pos_emb_qwenvl(grid_thw: List[List[int]],
|
||||
hidden_size: int,
|
||||
head_num: int,
|
||||
spatial_merge_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: Union[int, str, torch.device] = "vacc"):
|
||||
#assert out_tensor.device.type == "vacc", f"please target vacc device, now is {out_tensor.device}"
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("vacc", device)
|
||||
|
||||
thws = []
|
||||
for i in grid_thw:
|
||||
thws.extend(i)
|
||||
return _torch_vacc.rot_pos_emb_qwenvl(thws,
|
||||
hidden_size,
|
||||
head_num,
|
||||
spatial_merge_size,
|
||||
dtype,
|
||||
device)
|
||||
|
||||
def fast_pos_embed_interpolate_qwenvl(weight: torch.Tensor,
|
||||
grid_thw: List[List[int]],
|
||||
num_grid_per_side: int,
|
||||
spatial_merge_size: int,
|
||||
hidden_dim: int):
|
||||
thws = []
|
||||
for i in grid_thw:
|
||||
thws.extend(i)
|
||||
return _torch_vacc.fast_pos_embed_interpolate_qwenvl(weight,
|
||||
thws,
|
||||
num_grid_per_side,
|
||||
spatial_merge_size,
|
||||
hidden_dim)
|
||||
# qwen2_vl and qwen3_vl img preocess op is same
|
||||
def qwen2vl_img_preprocess(
|
||||
image: "torch.Tensor",
|
||||
do_resize: bool,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
resized_height: int,
|
||||
resized_width: int,
|
||||
interpolation: int, #Optional["F.InterpolationMode"],
|
||||
patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
merge_size: int,
|
||||
image_mean0: float,
|
||||
image_mean1: float,
|
||||
image_mean2: float,
|
||||
image_std0: float,
|
||||
image_std1: float,
|
||||
image_std2: float,
|
||||
# batch_size: int = 1,
|
||||
# grid_t: int = 1,
|
||||
# channel: int = 3,
|
||||
# output: Optional[torch.Tensor] = None
|
||||
):
|
||||
assert image.device.type == "vacc", f"please target vacc device, now is {image.device}"
|
||||
return _torch_vacc.qwen2vl_img_preprocess(
|
||||
image,
|
||||
do_resize,
|
||||
min_pixels,
|
||||
max_pixels,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
resized_height,
|
||||
resized_width,
|
||||
interpolation,
|
||||
patch_size,
|
||||
temporal_patch_size,
|
||||
merge_size,
|
||||
image_mean0, image_mean1, image_mean2,
|
||||
image_std0, image_std1, image_std2
|
||||
)
|
||||
Reference in New Issue
Block a user