from __future__ import annotations

from typing import Tuple

import torch

from ._device import (
    current_device,
    device,
    device_count,
    get_device_capability,
    get_device_name,
    get_device_properties,
    is_available,
    is_bf16_supported,
    set_device,
    synchronize,
)
from .amp import (
    get_amp_supported_dtype,
    get_autocast_dtype,
    is_autocast_enabled,
    set_autocast_dtype,
    set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import (
    # caching_allocator_alloc,
    # caching_allocator_delete,
    empty_cache,
    get_allocator_backend,
    max_memory_allocated,
    max_memory_cached,
    max_memory_reserved,
    mem_get_info,
    memory_allocated,
    memory_cached,
    memory_reserved,
    memory_snapshot,
    memory_stats,
    memory_stats_as_nested_dict,
    memory_summary,
    reset_accumulated_memory_stats,
    reset_max_memory_allocated,
    reset_max_memory_cached,
    reset_peak_memory_stats,
    set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream


def init():
    r"""Initialize PyTorch's VACC state.

    You may need to call this explicitly if you are interacting with PyTorch
    via its C API, as Python bindings for VACC functionality will not be
    available until this initialization takes place. Ordinary users should
    not need this, as all of PyTorch's VACC methods automatically initialize
    VACC state on demand.

    Does nothing if the VACC state is already initialized.

    A minimal usage sketch appears at the end of this module.
    """
    _lazy_init()


# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]

from .custom_ops import *  # noqa: F403
from .custom_qwen3_ops import *  # noqa: F403
from .random import *  # noqa: F403

__all__ = [
    "device",
    "is_available",
    "is_bf16_supported",
    "current_device",
    "set_device",
    "device_count",
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "synchronize",
    "amp",
    "get_amp_supported_dtype",
    "is_autocast_enabled",
    "set_autocast_enabled",
    "get_autocast_dtype",
    "set_autocast_dtype",
    "_is_in_bad_fork",
    "_lazy_call",
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "set_stream",
    "current_stream",
    "default_generators",
    "default_stream",
    "stream",
    "Stream",
    "Event",
    "mem_get_info",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "get_allocator_backend",
    "rms_norm",
    "RotaryPosEmbedding",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_cp_forward",
    "scaled_dot_product_attention_cp_backward",
    "swiglu",
    "paged_attention",
    "reshape_and_cache_attention",
    "concat_and_cache_attention",
    "w8a8_block_fp8_matmul",
    "moe_expert_token_group_reassign",
    "fused_mlp_mm_fp8",
    "fused_mlp_fp8",
    "fused_moe_preprocess",
    "fused_residual_rmsnorm",
    "parallel_embedding",
    "all_reduce",
    "all_gather",
    "broadcast",
    "fused_mlp_moe_with_rmsnorm",
    "fuse_moe_decode_v2_allreduce",
    "topk_topp",
    "fused_mla",
    "fused_mla_allreduce",
    "fused_mlp_with_rmsnorm",
    "fused_mlp_allreduce",
    "ds3_sampler",
    "sampler_v1",
    "rejection_sampler",
    "rejection_sampler_update_hidden_states",
    "rejection_sampler_v1",
    "fused_matmul_allgather",
    "fused_mla_v2",
    "fused_mla_allreduce_v2",
    "mla_matmul_scale",
    "mla_matmul",
    "fused_mla_prefill_stage0",
"fused_mla_prefill_stage1", "fused_mla_prefill_stage0_allreduce", "fuse_moe_prefill_stage0", "fuse_mla_mlp_v2_allreduce_decode", "fuse_mla_moe_v2_allreduce_decode", "fuse_mla_mlp_v2_allreduce_decode_layers", "fuse_mla_moe_v2_allreduce_decode_layers", "fuse_mla_mlp_v2_allreduce_decode_layers_v2", "fuse_mla_moe_v2_allreduce_decode_layers_v2", "fuse_mlp_qwen_int4", "fuse_mlp_qwen_int4_reduce", "w4a8_block_int4_matmul", "fuse_atten_qwen3", "fuse_atten_qwen2", "qwen3_fuse_attention_moe_decode", "fuse_mtp_stage0", "fuse_mtp_allreduce", "roll_out", "fused_experts_int4_prefill", "fuse_bge_embedding_stage1", "l2_norm", "fuse_mlp_vision", "patch_merger_vision", "fuse_atten_vit", "apply_penalties", ]