146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
|
|
from typing import List, Optional, Tuple, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from torch import Generator
|
||
|
|
from torch_vacc._vacc_libs import _torch_vacc
|
||
|
|
|
||
|
|
|
||
|
|
def fuse_moe_prefill_stage0_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    gate_weight,
    rms_hidden_state_opt: Optional[torch.Tensor] = None,
    zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
    topk_ids_opt: Optional[torch.Tensor] = None,
    topk_weight_opt: Optional[torch.Tensor] = None,
):
    """Call the native ``fuse_moe_prefill_stage0_qwen`` op from ``_torch_vacc``.

    This wrapper adds no logic of its own: every argument is forwarded
    positionally, in declaration order, and the native op's result is
    returned unchanged.

    Args:
        hidden_states: Forwarded as the 1st positional argument.
        rms_residual: Forwarded as the 2nd positional argument.
        rms_weight: Forwarded as the 3rd positional argument.
        gate_weight: Forwarded as the 4th positional argument.
        rms_hidden_state_opt: Optional tensor, ``None`` by default.
        zero_moe_hidden_state_opt: Optional tensor, ``None`` by default.
        topk_ids_opt: Optional tensor, ``None`` by default.
        topk_weight_opt: Optional tensor, ``None`` by default.

    Returns:
        Whatever ``_torch_vacc.fuse_moe_prefill_stage0_qwen`` returns.
    """
    # NOTE(review): the ``*_opt`` tensors look like optional pre-allocated
    # buffers for the native op -- confirm against the C++ binding.
    native_op = _torch_vacc.fuse_moe_prefill_stage0_qwen
    forwarded = (
        hidden_states,
        rms_residual,
        rms_weight,
        gate_weight,
        rms_hidden_state_opt,
        zero_moe_hidden_state_opt,
        topk_ids_opt,
        topk_weight_opt,
    )
    return native_op(*forwarded)
|
||
|
|
|
||
|
|
|
||
|
|
def fuse_moe_decode_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    moe_weight_13,
    moe_weight_2,
    moe_weight_13_dequat,
    moe_weight_2_dequant,
    gate_weight,
    block_size_13,
    block_size_2,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: Optional[List[int]] = None,
    output: Optional[torch.Tensor] = None,
):
    """Call the native ``fuse_moe_decode_qwen`` op from ``_torch_vacc``.

    All arguments are forwarded positionally in declaration order. The only
    logic added here is filling in a default ``dev_info`` table when the
    caller does not supply one.

    Args:
        hidden_states: Forwarded unchanged.
        rms_residual: Forwarded unchanged.
        rms_weight: Forwarded unchanged.
        moe_weight_13: Forwarded unchanged.
        moe_weight_2: Forwarded unchanged.
        moe_weight_13_dequat: Forwarded unchanged.
        moe_weight_2_dequant: Forwarded unchanged.
        gate_weight: Forwarded unchanged.
        block_size_13: Forwarded unchanged.
        block_size_2: Forwarded unchanged.
        world_size: Forwarded unchanged; also the length of the generated
            default ``dev_info`` table.
        rank: Forwarded unchanged.
        group_id: Forwarded unchanged.
        dev_info: Per-rank integer table. When ``None`` or empty, it is
            replaced by ``[i | (i << 16) for i in range(world_size)]``.
        output: Optional tensor, ``None`` by default; forwarded unchanged.

    Returns:
        Whatever ``_torch_vacc.fuse_moe_decode_qwen`` returns.
    """
    # The declared default is None, so it must be checked before len();
    # previously omitting dev_info raised TypeError on len(None).
    if dev_info is None or len(dev_info) == 0:
        # Entry i carries i in its low 16 bits and again shifted left by 16.
        # NOTE(review): presumably two packed 16-bit device ids per rank --
        # confirm against the native op.
        dev_info = [i | (i << 16) for i in range(world_size)]
    return _torch_vacc.fuse_moe_decode_qwen(
        hidden_states,
        rms_residual,
        rms_weight,
        moe_weight_13,
        moe_weight_2,
        moe_weight_13_dequat,
        moe_weight_2_dequant,
        gate_weight,
        block_size_13,
        block_size_2,
        world_size,
        rank,
        group_id,
        dev_info,
        output,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def rot_pos_emb_qwenvl(
    grid_thw: List[List[int]],
    hidden_size: int,
    head_num: int,
    spatial_merge_size: int,
    dtype: torch.dtype,
    device: Union[int, str, torch.device] = "vacc",
):
    """Call the native ``rot_pos_emb_qwenvl`` op from ``_torch_vacc``.

    Two inputs are normalised before the call:

    * ``device``: a ``str`` becomes ``torch.device(device)``; an ``int``
      becomes ``torch.device("vacc", device)``; anything else is passed
      through untouched.
    * ``grid_thw``: the nested list is flattened into a single list.

    Args:
        grid_thw: Iterable of integer sequences, flattened in order.
        hidden_size: Forwarded unchanged.
        head_num: Forwarded unchanged.
        spatial_merge_size: Forwarded unchanged.
        dtype: Forwarded unchanged.
        device: Target device as a string, a device index, or a
            ``torch.device``. Defaults to ``"vacc"``.

    Returns:
        Whatever ``_torch_vacc.rot_pos_emb_qwenvl`` returns.
    """
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        # A bare integer is taken as an index on the "vacc" backend.
        device = torch.device("vacc", device)

    # The native op receives one flat list rather than a list of lists.
    flat_thw = [value for entry in grid_thw for value in entry]

    return _torch_vacc.rot_pos_emb_qwenvl(
        flat_thw,
        hidden_size,
        head_num,
        spatial_merge_size,
        dtype,
        device,
    )
|
||
|
|
|
||
|
|
def fast_pos_embed_interpolate_qwenvl(
    weight: torch.Tensor,
    grid_thw: List[List[int]],
    num_grid_per_side: int,
    spatial_merge_size: int,
    hidden_dim: int,
):
    """Call the native ``fast_pos_embed_interpolate_qwenvl`` op.

    ``grid_thw`` is flattened into a single list before the call; every
    other argument is forwarded unchanged.

    Args:
        weight: Tensor forwarded as the 1st positional argument.
        grid_thw: Iterable of integer sequences, flattened in order.
        num_grid_per_side: Forwarded unchanged.
        spatial_merge_size: Forwarded unchanged.
        hidden_dim: Forwarded unchanged.

    Returns:
        Whatever ``_torch_vacc.fast_pos_embed_interpolate_qwenvl`` returns.
    """
    # The native op receives one flat list rather than a list of lists.
    flat_thw = [value for entry in grid_thw for value in entry]
    return _torch_vacc.fast_pos_embed_interpolate_qwenvl(
        weight,
        flat_thw,
        num_grid_per_side,
        spatial_merge_size,
        hidden_dim,
    )
|
||
|
|
# The qwen2_vl and qwen3_vl image preprocess ops are the same.
|
||
|
|
def qwen2vl_img_preprocess(
    image: "torch.Tensor",
    do_resize: bool,
    min_pixels: int,
    max_pixels: int,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    resized_height: int,
    resized_width: int,
    interpolation: int,
    patch_size: int,
    temporal_patch_size: int,
    merge_size: int,
    image_mean0: float,
    image_mean1: float,
    image_mean2: float,
    image_std0: float,
    image_std1: float,
    image_std2: float,
):
    """Call the native ``qwen2vl_img_preprocess`` op from ``_torch_vacc``.

    The wrapper asserts that ``image`` lives on a ``"vacc"`` device and then
    forwards every argument positionally, in declaration order.

    Args:
        image: Input tensor; its ``device.type`` must be ``"vacc"``.
        do_resize: Forwarded unchanged.
        min_pixels: Forwarded unchanged.
        max_pixels: Forwarded unchanged.
        do_rescale: Forwarded unchanged.
        rescale_factor: Forwarded unchanged.
        do_normalize: Forwarded unchanged.
        resized_height: Forwarded unchanged.
        resized_width: Forwarded unchanged.
        interpolation: Integer code forwarded unchanged.
            NOTE(review): presumably maps to an ``InterpolationMode`` value
            -- confirm the mapping in the native op.
        patch_size: Forwarded unchanged.
        temporal_patch_size: Forwarded unchanged.
        merge_size: Forwarded unchanged.
        image_mean0: First of three mean values, forwarded unchanged.
        image_mean1: Second of three mean values, forwarded unchanged.
        image_mean2: Third of three mean values, forwarded unchanged.
        image_std0: First of three std values, forwarded unchanged.
        image_std1: Second of three std values, forwarded unchanged.
        image_std2: Third of three std values, forwarded unchanged.

    Returns:
        Whatever ``_torch_vacc.qwen2vl_img_preprocess`` returns.

    Raises:
        AssertionError: If ``image`` is not on a ``"vacc"`` device (only
            when assertions are enabled).
    """
    assert image.device.type == "vacc", f"please target vacc device, now is {image.device}"

    # Three mean values followed by three std values, in index order.
    mean_values = (image_mean0, image_mean1, image_mean2)
    std_values = (image_std0, image_std1, image_std2)

    return _torch_vacc.qwen2vl_img_preprocess(
        image,
        do_resize,
        min_pixels,
        max_pixels,
        do_rescale,
        rescale_factor,
        do_normalize,
        resized_height,
        resized_width,
        interpolation,
        patch_size,
        temporal_patch_size,
        merge_size,
        *mean_values,
        *std_values,
    )
|