from typing import List, Optional, Tuple, Union

import torch
from torch import Generator

from torch_vacc._vacc_libs import _torch_vacc


def _flatten_grid_thw(grid_thw: List[List[int]]) -> List[int]:
    """Concatenate the inner lists of ``grid_thw`` into one flat list.

    The native ops below take the grid description as a single flat list
    rather than a list of lists.
    """
    return [value for thw in grid_thw for value in thw]


def fuse_moe_prefill_stage0_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    gate_weight,
    rms_hidden_state_opt: Optional[torch.Tensor] = None,
    zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
    topk_ids_opt: Optional[torch.Tensor] = None,
    topk_weight_opt: Optional[torch.Tensor] = None,
):
    """Forward all arguments to ``_torch_vacc.fuse_moe_prefill_stage0_qwen``.

    This wrapper adds no logic of its own: every argument is passed
    positionally, in declaration order, and the native op's result is
    returned unchanged.

    NOTE(review): the ``*_opt`` tensors look like optional pre-allocated
    output buffers -- confirm against the native op's contract.
    """
    return _torch_vacc.fuse_moe_prefill_stage0_qwen(
        hidden_states,
        rms_residual,
        rms_weight,
        gate_weight,
        rms_hidden_state_opt,
        zero_moe_hidden_state_opt,
        topk_ids_opt,
        topk_weight_opt,
    )


def fuse_moe_decode_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    moe_weight_13,
    moe_weight_2,
    moe_weight_13_dequat,
    moe_weight_2_dequant,
    gate_weight,
    block_size_13,
    block_size_2,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: Optional[List[int]] = None,
    output: Optional[torch.Tensor] = None,
):
    """Forward all arguments to ``_torch_vacc.fuse_moe_decode_qwen``.

    Args:
        dev_info: Per-rank integer list handed to the native op. When it is
            ``None`` (the default) or empty, a default list of length
            ``world_size`` is built whose ``i``-th entry is
            ``i | (i << 16)``.
            NOTE(review): the meaning of the two 16-bit fields is defined by
            the native op -- confirm before relying on it.
        output: Passed through unchanged; presumably an optional
            pre-allocated result tensor -- TODO confirm.

    All remaining arguments are passed through positionally without
    modification. The parameter name ``moe_weight_13_dequat`` (sic) is kept
    as-is so existing keyword callers keep working.

    Returns:
        Whatever ``_torch_vacc.fuse_moe_decode_qwen`` returns.
    """
    # The default is None, so it must be checked before len(); calling
    # len(None) would raise TypeError on every call that omits dev_info.
    if dev_info is None or len(dev_info) == 0:
        dev_info = [i | (i << 16) for i in range(world_size)]

    return _torch_vacc.fuse_moe_decode_qwen(
        hidden_states,
        rms_residual,
        rms_weight,
        moe_weight_13,
        moe_weight_2,
        moe_weight_13_dequat,
        moe_weight_2_dequant,
        gate_weight,
        block_size_13,
        block_size_2,
        world_size,
        rank,
        group_id,
        dev_info,
        output,
    )


def rot_pos_emb_qwenvl(
    grid_thw: List[List[int]],
    hidden_size: int,
    head_num: int,
    spatial_merge_size: int,
    dtype: torch.dtype,
    device: Union[int, str, torch.device] = "vacc",
):
    """Call ``_torch_vacc.rot_pos_emb_qwenvl`` with normalized arguments.

    Args:
        grid_thw: List of integer lists; flattened into a single list before
            being handed to the native op.
        hidden_size: Passed through unchanged.
        head_num: Passed through unchanged.
        spatial_merge_size: Passed through unchanged.
        dtype: Passed through unchanged.
        device: A string is converted with ``torch.device(device)``; an int
            is treated as an index and converted with
            ``torch.device("vacc", device)``; any other value (e.g. an
            existing ``torch.device``) is passed through as-is.

    Returns:
        Whatever ``_torch_vacc.rot_pos_emb_qwenvl`` returns.
    """
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("vacc", device)

    return _torch_vacc.rot_pos_emb_qwenvl(
        _flatten_grid_thw(grid_thw),
        hidden_size,
        head_num,
        spatial_merge_size,
        dtype,
        device,
    )


def fast_pos_embed_interpolate_qwenvl(
    weight: torch.Tensor,
    grid_thw: List[List[int]],
    num_grid_per_side: int,
    spatial_merge_size: int,
    hidden_dim: int,
):
    """Call ``_torch_vacc.fast_pos_embed_interpolate_qwenvl``.

    ``grid_thw`` is flattened into a single list of ints; every other
    argument is passed through unchanged, and the native op's result is
    returned as-is.
    """
    return _torch_vacc.fast_pos_embed_interpolate_qwenvl(
        weight,
        _flatten_grid_thw(grid_thw),
        num_grid_per_side,
        spatial_merge_size,
        hidden_dim,
    )


# qwen2_vl and qwen3_vl share the same image preprocessing op.
def qwen2vl_img_preprocess(
    image: torch.Tensor,
    do_resize: bool,
    min_pixels: int,
    max_pixels: int,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    resized_height: int,
    resized_width: int,
    interpolation: int,  # presumably an F.InterpolationMode code -- TODO confirm
    patch_size: int,
    temporal_patch_size: int,
    merge_size: int,
    image_mean0: float,
    image_mean1: float,
    image_mean2: float,
    image_std0: float,
    image_std1: float,
    image_std2: float,
):
    """Forward all arguments to ``_torch_vacc.qwen2vl_img_preprocess``.

    The only work done here is checking that ``image`` is on a ``vacc``
    device; all arguments are then passed positionally, in declaration
    order, and the native op's result is returned unchanged.

    Raises:
        AssertionError: If ``image.device.type`` is not ``"vacc"``. The
            check is an ``assert``, so it is skipped when Python runs with
            ``-O``.
    """
    assert image.device.type == "vacc", f"please target vacc device, now is {image.device}"

    return _torch_vacc.qwen2vl_img_preprocess(
        image,
        do_resize,
        min_pixels,
        max_pixels,
        do_rescale,
        rescale_factor,
        do_normalize,
        resized_height,
        resized_width,
        interpolation,
        patch_size,
        temporal_patch_size,
        merge_size,
        image_mean0,
        image_mean1,
        image_mean2,
        image_std0,
        image_std1,
        image_std2,
    )