from __future__ import annotations
from typing import Tuple
import torch
from ._device import (
    current_device,
    device,
    device_count,
    get_device_capability,
    get_device_name,
    get_device_properties,
    is_available,
    is_bf16_supported,
    set_device,
    synchronize,
)
from .amp import (
    get_amp_supported_dtype,
    get_autocast_dtype,
    is_autocast_enabled,
    set_autocast_dtype,
    set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import (  # caching_allocator_alloc,; caching_allocator_delete,
    empty_cache,
    get_allocator_backend,
    max_memory_allocated,
    max_memory_cached,
    max_memory_reserved,
    mem_get_info,
    memory_allocated,
    memory_cached,
    memory_reserved,
    memory_snapshot,
    memory_stats,
    memory_stats_as_nested_dict,
    memory_summary,
    reset_accumulated_memory_stats,
    reset_max_memory_allocated,
    reset_max_memory_cached,
    reset_peak_memory_stats,
    set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream


def init():
    r"""Initialize PyTorch's VACC state. You may need to call
    this explicitly if you are interacting with PyTorch via
    its C API, as Python bindings for VACC functionality will not
    be available until this initialization takes place. Ordinary users
    should not need this, as all of PyTorch's VACC methods
    automatically initialize VACC state on-demand.

    Does nothing if the VACC state is already initialized.
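
    Example (a minimal sketch; the import name below is an assumption,
    since how this backend package is exposed depends on the build)::

        >>> import torch_vacc  # hypothetical import path for this package
        >>> torch_vacc.init()  # doctest: +SKIP
        >>> torch_vacc.is_available()  # doctest: +SKIP
        True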
"""
_lazy_init()


# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
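# Once _lazy_init() has run, this tuple holds one generator per device
# (mirroring ``torch.cuda.default_generators``), so a sketch of typical use
# is ``default_generators[current_device()]`` for the active device's RNG.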
from .custom_ops import *  # noqa: F403
from .custom_qwen3_ops import *  # noqa: F403
from .random import *  # noqa: F403

__all__ = [
    "device",
    "is_available",
    "is_bf16_supported",
    "current_device",
    "set_device",
    "device_count",
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "synchronize",
    "amp",
    "get_amp_supported_dtype",
    "is_autocast_enabled",
    "set_autocast_enabled",
    "get_autocast_dtype",
    "set_autocast_dtype",
    "_is_in_bad_fork",
    "_lazy_call",
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "set_stream",
    "current_stream",
    "default_generators",
    "default_stream",
    "stream",
    "Stream",
    "Event",
    "mem_get_info",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "get_allocator_backend",
    "rms_norm",
    "RotaryPosEmbedding",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_cp_forward",
    "scaled_dot_product_attention_cp_backward",
    "swiglu",
    "paged_attention",
    "reshape_and_cache_attention",
    "concat_and_cache_attention",
    "w8a8_block_fp8_matmul",
    "moe_expert_token_group_reassign",
    "fused_mlp_mm_fp8",
    "fused_mlp_fp8",
    "fused_moe_preprocess",
    "fused_residual_rmsnorm",
    "parallel_embedding",
    "all_reduce",
    "all_gather",
    "broadcast",
    "fused_mlp_moe_with_rmsnorm",
    "fuse_moe_decode_v2_allreduce",
    "topk_topp",
    "fused_mla",
    "fused_mla_allreduce",
    "fused_mlp_with_rmsnorm",
    "fused_mlp_allreduce",
    "ds3_sampler",
    "sampler_v1",
    "rejection_sampler",
    "rejection_sampler_update_hidden_states",
    "rejection_sampler_v1",
    "fused_matmul_allgather",
    "fused_mla_v2",
    "fused_mla_allreduce_v2",
    "mla_matmul_scale",
    "mla_matmul",
    "fused_mla_prefill_stage0",
    "fused_mla_prefill_stage1",
    "fused_mla_prefill_stage0_allreduce",
    "fuse_moe_prefill_stage0",
    "fuse_mla_mlp_v2_allreduce_decode",
    "fuse_mla_moe_v2_allreduce_decode",
    "fuse_mla_mlp_v2_allreduce_decode_layers",
    "fuse_mla_moe_v2_allreduce_decode_layers",
    "fuse_mla_mlp_v2_allreduce_decode_layers_v2",
    "fuse_mla_moe_v2_allreduce_decode_layers_v2",
    "fuse_mlp_qwen_int4",
    "fuse_mlp_qwen_int4_reduce",
    "w4a8_block_int4_matmul",
    "fuse_atten_qwen3",
    "fuse_atten_qwen2",
    "qwen3_fuse_attention_moe_decode",
    "fuse_mtp_stage0",
    "fuse_mtp_allreduce",
    "roll_out",
    "fused_experts_int4_prefill",
    "fuse_bge_embedding_stage1",
    "l2_norm",
    "fuse_mlp_vision",
    "patch_merger_vision",
    "fuse_atten_vit",
    "apply_penalties",
]