#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) #pragma once #include #include #include #include namespace at { namespace xpu { struct XPUGraph; } struct XPUGeneratorState : public c10::intrusive_ptr_target { uint64_t seed_; uint64_t philox_offset_per_thread_; uint32_t offset_intragraph_; bool capturing_{}; std::unordered_set registered_graphs_; at::TensorBase seed_extragraph_{}; at::TensorBase offset_extragraph_{}; XPUGeneratorState( uint64_t seed = default_rng_seed_val, uint64_t philox_offset_per_thread = 0, uint32_t offset_intragraph = 0) : seed_(seed), philox_offset_per_thread_(philox_offset_per_thread), offset_intragraph_(offset_intragraph) {} void increase(uint64_t increment); void register_graph(xpu::XPUGraph* graph); void unregister_graph(xpu::XPUGraph* graph); void capture_prologue(); uint64_t capture_epilogue(); void replay_prologue(uint64_t wholegraph_increment); c10::intrusive_ptr clone(); }; struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { // Constructors XPUGeneratorImpl(DeviceIndex device_index = -1); XPUGeneratorImpl( DeviceIndex device_index, c10::intrusive_ptr state_); ~XPUGeneratorImpl() override = default; // XPUGeneratorImpl methods std::shared_ptr clone() const; void set_current_seed(uint64_t seed) override; void set_offset(uint64_t offset) override; uint64_t get_offset() const override; uint64_t current_seed() const override; uint64_t seed() override; void set_state(const c10::TensorImpl& new_state) override; c10::intrusive_ptr get_state() const override; void graphsafe_set_state( const c10::intrusive_ptr& state) override; c10::intrusive_ptr graphsafe_get_state() const override; void set_philox_offset_per_thread(uint64_t offset); uint64_t philox_offset_per_thread() const; void register_graph(xpu::XPUGraph* graph); void unregister_graph(xpu::XPUGraph* graph); PhiloxXpuState philox_xpu_state(uint64_t increment); std::pair philox_engine_inputs(uint64_t increment); static c10::DeviceType device_type(); private: XPUGeneratorImpl* clone_impl() const override; c10::intrusive_ptr state_; }; namespace xpu::detail { TORCH_XPU_API const Generator& getDefaultXPUGenerator(DeviceIndex device = -1); TORCH_XPU_API Generator createXPUGenerator(DeviceIndex device = -1); } // namespace xpu::detail } // namespace at #else #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)