#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) #pragma once #include #include #include #include #include #include namespace at { struct Generator; struct XPUGeneratorState; namespace xpu { TORCH_XPU_API MempoolId_t graph_pool_handle(); using xpuGraph_t = sycl::ext::oneapi::experimental::command_graph< sycl::ext::oneapi::experimental::graph_state::modifiable>; using xpuGraphExec_t = sycl::ext::oneapi::experimental::command_graph< sycl::ext::oneapi::experimental::graph_state::executable>; struct TORCH_XPU_API XPUGraph { XPUGraph(bool keep_graph = false); ~XPUGraph(); void register_generator_state( c10::intrusive_ptr state); void register_generator_state(const at::Generator& generator); void capture_begin(MempoolId_t pool = {0, 0}); void capture_end(); void instantiate(); void replay(); void reset(); MempoolId_t pool(); void enable_debug_mode(); void debug_dump(const std::string& debug_path); xpuGraph_t* raw_xpu_graph(); xpuGraphExec_t* raw_xpu_graph_exec(); protected: std::unique_ptr graph_; std::unique_ptr graph_exec_; bool has_graph_ = false; bool capture_ended_ = false; bool has_graph_exec_ = false; MempoolId_t mempool_id_; at::xpu::XPUStream capture_stream_; // GeneratorState and whole graph offset increments mapping ska::flat_hash_map, uint64_t> captured_generator_states_; static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1; c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE}; bool keep_graph_; }; } // namespace xpu } // namespace at #else #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)