#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
//  Copyright © 2022 Apple Inc.

#pragma once
// NOTE(review): the include targets below were restored from context; confirm
// them against the build before merging.
#include <initializer_list>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <fmt/format.h>
#include <torch/library.h>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

// Category declaring min/max graph nodes used by the MPS backend; the
// implementations live outside this header.
@interface MPSGraph (PyTorchFixups)
- (MPSGraphTensor*)minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor
                                                            secondaryTensor:(MPSGraphTensor*)secondaryTensor
                                                                       name:(NSString*)name;

- (MPSGraphTensor*)maximumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor
                                                            secondaryTensor:(MPSGraphTensor*)secondaryTensor
                                                                       name:(NSString*)name;
@end

using namespace at::mps;

namespace at::native::mps {

// A scalar value together with the buffer that backs it on the device.
// `size` is the number of bytes of `value` that are meaningful for `type`
// (it is what mtl_setBytes(encoder, MPSScalar, idx) uploads).
struct MPSScalar {
  // Reinterprets the pointer held by `buffer` as an MTLBuffer handle.
  id<MTLBuffer> getMTLBuffer() const {
    return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
  }

  size_t size = 0;
  ScalarType type = ScalarType::Undefined;
  c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
  union {
    float f; // MPS doesn't support 'double'
    at::Half h;
    int64_t i;
    bool b;
    c10::complex<float> cf;
    c10::complex<at::Half> ch;
    at::BFloat16 bf16;
  } value{};
};

void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);

// Each of the helpers below has a ScalarType overload (defined elsewhere) and a
// TensorBase convenience overload that forwards the tensor's scalar_type().
MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const TensorBase& t) {
  return getMPSDataType(t.scalar_type());
}
MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const TensorBase& t) {
  return getMPSScalarType(t.scalar_type());
}
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) {
  return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
static inline std::string scalarToMetalTypeString(const TensorBase& t) {
  return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string to_hex_key(float);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);

MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
// The MPSShape could vary based on memory format
Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

// Determines whether a tensor is too large to use MPSGraph
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI = true);

// Reinterprets the tensor's storage data pointer as an MTLBuffer handle.
static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

// Pairs an MPSGraphTensor with the MPSGraphTensorData that feeds it.
// A placeholder without tensor data is reported as "intermediate".
class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
  Placeholder(MPSGraphTensor* mpsGraphTensor,
              const Tensor& self,
              MPSShape* mpsShape = nullptr,
              bool gatherTensorData = true,
              MPSDataType dataType = MPSDataTypeInvalid,
              bool useMPSStridedAPI = true);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }
  MPSGraphTensorData* getMPSGraphTensorData() {
    return _value;
  }
  bool isIntermediate() {
    return _value == nullptr;
  }

 private:
  MPSGraphTensor* _placeholder;
  MPSGraphTensorData* _value;
  Tensor _tensor;
};

void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);

MPSGraph* make_mps_graph();

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);

std::string get_mem_format_string(c10::MemoryFormat memory_format);

using MPSCacheKey = uint64_t;

// Owns one retained NSObject (the cached kernel): retained in the constructor,
// released in the destructor. Not copyable, so the release happens exactly once.
struct MPSCachedKernel {
  MPSCachedKernel(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedKernel() {
    [_object release];
    _object = nullptr;
  }

  // Delete copy constructor and assignment
  MPSCachedKernel(const MPSCachedKernel&) = delete;
  void operator=(const MPSCachedKernel&) = delete;

  // Unchecked cast of the stored object to the requested Objective-C type.
  template <typename T>
  inline T* kernel() const {
    return (T*)_object;
  }

 private:
  NSObject* _object = nullptr;
};

// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph {
  MPSCachedGraph(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
    [_object release];
    _object = nullptr;
  }

  // Unchecked downcast of this cache entry to a derived cached-graph type.
  template <typename T>
  inline T* as() {
    return static_cast<T*>(this);
  }

  MPSGraph* graph() const {
    return (MPSGraph*)_object;
  }
  NSObject* object() const {
    return _object;
  }

 private:
  NSObject* _object = nullptr;
};

// Cached graph with one input and one output tensor.
struct MPSUnaryCachedGraph : public MPSCachedGraph {
  MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

// Cached graph for the backward pass of a unary op.
struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
  MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output
  MPSGraphTensor* gradInputTensor_ = nil;
};

// Cached graph with two inputs and one output tensor.
struct MPSBinaryCachedGraph : public MPSCachedGraph {
  MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

// Cached graph for the backward pass of a binary op.
struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
  MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* gradInputTensor_ = nil;
};

// Process-wide cache mapping the hash of a string key to an owned
// MPSCachedKernel. All map accesses run on a private serial dispatch queue.
struct MPSKernelCache {
  typedef MPSCachedKernel* (^CreateCachedKernelBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {}
    MPSCachedKernel* cachedKernel_ = nullptr;
    std::string key_;
  };

 public:
  // Lazily creates the singleton on first use.
  // NOTE(review): the null check itself is not synchronized.
  static MPSKernelCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSKernelCache();
    }
    return _instance_cache;
  }

  // Releases the queue and deletes every kernel owned by the cache.
  ~MPSKernelCache() {
    dispatch_release(serialQueue_);
    for (const auto& i : cache_) {
      delete i.second.cachedKernel_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSKernelCache(const MPSKernelCache&) = delete;
  void operator=(const MPSKernelCache&) = delete;

  // Returns the kernel cached under `key`; if there is none, runs
  // `createCacheBlock` on the serial queue and stores its result.
  MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
    __block MPSCachedKernel* cachedKernel = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // Single lookup instead of count() followed by at().
      auto it = cache_.find(hash);
      if (it != cache_.end()) {
        auto& entry = it->second;
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
        cachedKernel = entry.cachedKernel_;
      } else {
        cachedKernel = createCacheBlock();
        CacheEntry entry(key, cachedKernel);
        cache_.emplace(hash, entry);
      }
    });
    return cachedKernel;
  }

  template <typename T>
  inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
    return static_cast<T*>(CreateCachedKernel(key, createCacheBlock));
  }

  // Returns the kernel cached under `key`, or nil when there is none.
  MPSCachedKernel* LookUp(const std::string& key) const {
    __block MPSCachedKernel* cachedKernel = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      auto it = cache_.find(hash);
      if (it != cache_.end()) {
        auto& entry = it->second;
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
        cachedKernel = entry.cachedKernel_;
      }
    });
    return cachedKernel;
  }

  template <typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T*>(LookUp(key));
  }

 private:
  MPSKernelCache() {
    serialQueue_ = dispatch_queue_create("kernel cache queue", DISPATCH_QUEUE_SERIAL);
  }

  static MPSKernelCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common template for creating cached kernel if missing
template <typename T>
inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function<MPSKernel*()> instantiate) {
  auto cache_ = MPSKernelCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  // CreateCachedKernel re-checks the map on the serial queue, so `instantiate`
  // only runs when the key is still missing at that point.
  return cache_->CreateCachedKernelAs<T>(key, ^mps::MPSCachedKernel*() {
    auto k_ = new mps::MPSCachedKernel(instantiate());
    return k_;
  });
}

// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache {
  typedef MPSCachedGraph* (^CreateCachedGraphBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
    MPSCachedGraph* cachedGraph_ = nullptr;
    std::string key_;
  };

 public:
  // Lazily creates the singleton on first use.
  // NOTE(review): the null check itself is not synchronized.
  static MPSGraphCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();
    }
    return _instance_cache;
  }

  // Releases the queue and deletes every graph owned by the cache.
  ~MPSGraphCache() {
    dispatch_release(serialQueue_);

    for (const auto& i : cache_) {
      delete i.second.cachedGraph_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSGraphCache(const MPSGraphCache&) = delete;
  void operator=(const MPSGraphCache&) = delete;

  // Returns the graph cached under `key`; if there is none, runs
  // `createCacheBlock` on the serial queue, stores its result and reports the
  // new entry to the profiler hook.
  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    __block MPSCachedGraph* cachedGraph = nil;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // verify the cached entry doesn't already exist
      auto it = cache_.find(hash);
      if (it != cache_.end()) {
        auto& entry = it->second;
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
      } else {
        cachedGraph = createCacheBlock();
        CacheEntry entry(key, cachedGraph);
        cache_.emplace(hash, entry);
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template <typename T>
  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    return static_cast<T*>(CreateCachedGraph(key, createCacheBlock));
  }

  // Returns the graph cached under `key` (and reports the hit to the profiler
  // hook), or nullptr when there is none.
  MPSCachedGraph* LookUp(const std::string& key) const {
    __block MPSCachedGraph* cachedGraph = nullptr;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    // The block contains an assertion that can throw, so dispatch through the
    // same rethrowing helper the other cache accessors use rather than a bare
    // dispatch_sync.
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      auto it = cache_.find(hash);
      if (it != cache_.end()) {
        auto& entry = it->second;
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template <typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T*>(LookUp(key));
  }

 private:
  MPSGraphCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }
  // this is defined in OperationUtils.mm to not include
  // MPSProfiler.h in header OperationUtils.h
  void profileCachedGraph(const CacheEntry& cacheEntry) const;

  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common template for creating graph with a specified cache if missing
template <typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
  auto cache_ = MPSGraphCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
    T* newCachedGraph = nil;
    @autoreleasepool {
      // Initialize graph
      auto mpsGraph = mps::make_mps_graph();
      newCachedGraph = new T(mpsGraph);
      instantiate(mpsGraph, newCachedGraph);
    }
    return newCachedGraph;
  });
}

// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

/**
 * Returns distance from lowest to highest element offset in given tensor.
 */
size_t compute_storage_numel_distance(const TensorBase& t);

/**
 * Checks whether tensor is mapped to a contiguous area in the storage.
 */
inline bool is_dense_in_storage(const TensorBase& t) {
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}

// Binds tensor `t` at argument index `idx` of a compute or argument encoder.
// MPS tensors are bound as buffers at their storage offset. A CPU tensor is
// accepted only by a compute encoder and only when it is 0-dimensional, in
// which case its value is uploaded inline with setBytes.
template <typename encoder_t,
          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
  if (C10_UNLIKELY(t.device().type() == kCPU)) {
    if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
      TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
      // MPS does not support doubles, silently downcast CPU scalar to float
      if (C10_UNLIKELY(t.scalar_type() == kDouble)) {
        auto val = static_cast<float>(*reinterpret_cast<const double*>(t.const_data_ptr()));
        [encoder setBytes:&val length:sizeof(val) atIndex:idx];
        return;
      }
      if (C10_UNLIKELY(t.scalar_type() == kComplexDouble)) {
        auto val = static_cast<c10::complex<float>>(*reinterpret_cast<const c10::complex<double>*>(t.const_data_ptr()));
        [encoder setBytes:&val length:sizeof(val) atIndex:idx];
        return;
      }
      [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx];
    } else {
      TORCH_CHECK(false, "Passed CPU tensor to MPS op");
    }
    return;
  }
  [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}

// Implementation of setBytes for containers vs trivially copiable types must be separate
// Containers like `std::array` could have been uploaded directly, but `c10::ArrayRef`,
// while trivially copiable, includes padding which if copied as Metal shader parameters
// might overwrite other values
template <
    typename T,
    typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float> ||
                                (std::is_class_v<T> && std::is_trivially_copyable_v<T> && !detail::has_size_type_v<T>)>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
  [encoder setBytes:&val length:sizeof(T) atIndex:idx];
}

// Container overload: uploads size() elements starting at data().
template <typename Container, typename = std::enable_if_t<detail::has_size_type_v<Container>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
  [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}

// MPSScalar overload: uploads the first `s.size` bytes of the value union.
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const MPSScalar& s, unsigned idx) {
  [encoder setBytes:&s.value length:s.size atIndex:idx];
}

// Byte offset of operand `idx` of `iter` from the start of its storage.
// Declared `inline` like the other header helpers so translation units that do
// not call it get no unused-function copy.
static inline size_t iter_tensor_offset(TensorIteratorBase& iter, unsigned idx) {
  // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id<MTLBuffer> object
  // But TensorIterator constructs data_ptr as if base was just a raw pointer
  // Workaround this problem by computing an offset from the start of the tensor, which works for both
  // tensor views and sliced 64-bit iterators
  return reinterpret_cast<size_t>(iter.data_ptr(idx)) -
      reinterpret_cast<size_t>(iter.tensor_base(idx).storage().data());
}

// Binds the first `ntensors` operands of `iter` (all of them by default) to
// consecutive encoder argument indices starting at 0.
static inline void bind_iter_tensors(id<MTLComputeCommandEncoder> encoder,
                                     TensorIteratorBase& iter,
                                     std::optional<size_t> ntensors = std::nullopt) {
  for (auto idx : c10::irange(ntensors.value_or(iter.ntensors()))) {
    auto& t = iter.tensor_base(idx);
    // Handle CPU scalars
    if (C10_UNLIKELY(t.device().type() == kCPU)) {
      mtl_setBuffer(encoder, t, idx);
      continue;
    }
    auto offs = iter_tensor_offset(iter, idx);
    [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
  }
}

namespace detail {
// Generic argument: uploaded by value through mtl_setBytes.
template <typename T>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const T& val, unsigned idx) {
  mtl_setBytes(encoder, val, idx);
}

// Raw Metal buffer: bound at offset 0.
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> val, unsigned idx) {
  [encoder setBuffer:val offset:0 atIndex:idx];
}

template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const Tensor& val, unsigned idx) {
  mtl_setBuffer(encoder, val, idx);
}

// An empty optional binds nothing; the argument slot is left untouched.
template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const std::optional<Tensor>& val, unsigned idx) {
  if (val.has_value()) {
    mtl_setBuffer(encoder, val.value(), idx);
  }
}

template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const TensorBase& val, unsigned idx) {
  mtl_setBuffer(encoder, val, idx);
}
// MPS does not support doubles, so cast it down to float before passing as an argument
template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const double& val, unsigned idx) {
  float val_f = static_cast<float>(val);
  mtl_setBytes(encoder, val_f, idx);
}
} // namespace detail

// Binds `val` at index `idx` (terminal case of the variadic overload below).
template <unsigned idx = 0, typename T>
static inline void mtl_setArgs(id<MTLComputeCommandEncoder> encoder, const T& val) {
  detail::mtl_setArg(encoder, val, idx);
}

// Binds each argument to consecutive indices idx, idx + 1, ...
template <unsigned idx = 0, typename T, typename... Args>
static inline void mtl_setArgs(id<MTLComputeCommandEncoder> encoder, const T& val, Args&&... args) {
  detail::mtl_setArg(encoder, val, idx);
  mtl_setArgs<idx + 1>(encoder, std::forward<Args>(args)...);
}

// Dispatches `length` threads along x, using threadgroups of
// min(maxTotalThreadsPerThreadgroup, length) threads.
static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
                                     id<MTLComputePipelineState> cplState,
                                     NSUInteger length) {
  static_assert(sizeof(NSUInteger) == sizeof(uint64_t));
  const auto maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
  auto size = MTLSizeMake(length, 1, 1);
  auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
  [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
                                        const TensorIteratorBase& iter,
                                        bool use_64bit_index = false);

// Build a feeds/results dictionary (graph tensor -> tensor data) from one to
// four placeholders.
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()};
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
    p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(),
  };
}

// Convenience overload for graphs with a single result placeholder.
inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}

// MPS yet to support double types, but starting from MacOS 14, supports bfloat16
inline bool supportedFloatingType(ScalarType dtype) {
  return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
}

inline bool supportedFloatingType(const TensorBase& t) {
  return supportedFloatingType(t.scalar_type());
}

inline bool supportedFloatingOrComplexType(ScalarType dtype) {
  if (dtype == kComplexFloat || dtype == kComplexHalf) {
    return true;
  }
  return supportedFloatingType(dtype);
}
inline bool supportedFloatingOrComplexType(const TensorBase& t) {
  return supportedFloatingOrComplexType(t.scalar_type());
}

// True when, on systems older than macOS 15, the tensor is non-contiguous or
// starts at a non-zero storage offset. The OS query is evaluated once.
inline bool needsGather(const TensorBase& t) {
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}

// Runs the unary Metal kernel `<name>_{dense|strided}_<out>_<in>_<params_type_name>`
// over `iter`, passing `params` by value after the tensor (and, for strided
// iterators, shape/stride) arguments.
template <typename T>
void MetalShaderLibrary::exec_unary_kernel_with_params(TensorIteratorBase& iter,
                                                       const std::string& name,
                                                       T params,
                                                       const std::string& params_type_name) {
  using namespace at::mps;
  // Decompose 64-bit tensor into 32-bit ones
  if (!iter.can_use_32bit_indexing()) {
    for (auto&& sub_iter : iter.with_32bit_indexing()) {
      exec_unary_kernel_with_params(sub_iter, name, params, params_type_name);
    }
    return;
  }

  auto inputTensor = iter.input(0);
  auto outputTensor = iter.output(0);
  uint32_t length = iter.numel();
  if (length == 0) {
    return;
  }
  auto kernel_name = fmt::format("{}_{}_{}_{}_{}",
                                 name,
                                 iter.is_contiguous() ? "dense" : "strided",
                                 scalarToMetalTypeString(outputTensor),
                                 scalarToMetalTypeString(inputTensor),
                                 params_type_name);
  @autoreleasepool {
    auto cplState = getPipelineStateForFunc(kernel_name);

    MPSStream* mpsStream = getCurrentMPSStream();
    // bind_iter_tensors() can raise through TORCH_CHECK, so dispatch with the
    // rethrowing helper, matching exec_binary_kernel_with_params.
    dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
      auto computeEncoder = mpsStream->commandEncoder();

      getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});

      [computeEncoder setComputePipelineState:cplState];
      bind_iter_tensors(computeEncoder, iter);
      if (!iter.is_contiguous()) {
        // Indices 2..5: output sizes, input strides, output strides, ndim.
        mtl_setArgs<2>(computeEncoder,
                       outputTensor.sizes(),
                       inputTensor.strides(),
                       outputTensor.strides(),
                       inputTensor.ndimension());
      }
      detail::mtl_setArg(computeEncoder, params, iter.is_contiguous() ? 2 : 6);
      mtl_dispatch1DJob(computeEncoder, cplState, length);

      getMPSProfiler().endProfileKernel(cplState);
    });
  }
}

// Runs a binary Metal kernel over `iter`, passing `params` by value at index 3.
// The kernel name encodes dense/strided layout, whether the two inputs need a
// cast (their dtypes differ), the dtypes involved and `params_type_name`.
template <typename T>
void MetalShaderLibrary::exec_binary_kernel_with_params(TensorIteratorBase& iter,
                                                        const std::string& name,
                                                        T params,
                                                        const std::string& params_type_name) {
  using namespace mps;
  // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?)
  // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with
  // double as common dtype (because Python floating point are always 64-bit values)
  TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS");

  // Skip for empty iterators
  if (iter.numel() == 0) {
    return;
  }

  // Decompose 64-bit tensor into 32-bit ones
  if (!iter.can_use_32bit_indexing()) {
    for (auto&& sub_iter : iter.with_32bit_indexing()) {
      exec_binary_kernel_with_params(sub_iter, name, params, params_type_name);
    }
    return;
  }

  // Only 0-dim (scalar) operands are narrowed; other tensors are left as is.
  auto convert_double_scalar = [](Tensor& t) {
    if (t.dim() != 0) {
      return;
    }
    if (t.scalar_type() == kDouble) {
      t = t.to(kFloat);
    } else if (t.scalar_type() == kComplexDouble) {
      t = t.to(kComplexFloat);
    }
  };

  Tensor input = iter.input(0);
  Tensor other = iter.input(1);
  Tensor out = iter.output();

  convert_double_scalar(input);
  convert_double_scalar(other);

  MPSStream* mpsStream = getCurrentMPSStream();
  const auto cast_needed = input.scalar_type() != other.scalar_type();
  const auto suffix = iter.is_contiguous() ? "dense" : "strided";
  // TODO: Implicitly pass both input and output types to non-cast kernels
  const auto kernel_name = cast_needed
      ? fmt::format("{}_{}_cast_{}_{}", name, suffix, scalarToMetalTypeString(out), params_type_name)
      : fmt::format("{}_{}_{}_{}_{}",
                    name,
                    suffix,
                    scalarToMetalTypeString(out),
                    scalarToMetalTypeString(input),
                    params_type_name);
  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      auto computeEncoder = mpsStream->commandEncoder();
      auto binaryPSO = getPipelineStateForFunc(kernel_name);
      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
      [computeEncoder setComputePipelineState:binaryPSO];
      // Set input and output tensors
      bind_iter_tensors(computeEncoder, iter);
      // Iterator is contiguous if all of its elements are dense in storage,
      // i.e. it's true for both row-first and column-first tensors
      if (iter.is_contiguous()) {
        detail::mtl_setArg(computeEncoder, params, 3);
        if (cast_needed) {
          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
                                               static_cast<int>(c10::elementSize(other.scalar_type())),
                                               static_cast<int>(input.scalar_type()),
                                               static_cast<int>(other.scalar_type())};
          mtl_setBytes(computeEncoder, size_and_types, 4);
        }
      } else {
        // Please note that shapes and strides of the iterator might be
        // different than that of its operands, for example binary op
        // between 4x4 tensor and scalar will result in 1D 16 element iterator
        std::array<int, 4> ndim_and_types = {iter.ndim(),
                                             static_cast<int>(input.scalar_type()),
                                             static_cast<int>(other.scalar_type()),
                                             static_cast<int>(out.scalar_type())};
        mtl_setArgs<3>(
            computeEncoder, params, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), ndim_and_types);
      }
      mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel());
      getMPSProfiler().endProfileKernel(binaryPSO);
    }
  });
}

// Checks if one tensor is broadcastable into another
// Both tensors must be contiguous. Dimensions of `from` are compared against
// `into` starting from the innermost one; once a size differs, every remaining
// (more outer) dimension of `from` has to be 1.
// Declared `inline` like the other header helpers so translation units that do
// not call it get no unused-function copy.
static inline bool is_dense_broadcastable(const Tensor& from, const Tensor& into) {
  if (!from.is_contiguous() || !into.is_contiguous()) {
    return false;
  }
  bool checking_squeezable_dims = false;
  for (const auto dim : c10::irange(from.ndimension())) {
    if (checking_squeezable_dims) {
      if (from.size(-dim - 1) == 1) {
        continue;
      }
      return false;
    }
    checking_squeezable_dims = from.size(-dim - 1) != into.size(-dim - 1);
  }
  return true;
}

} // namespace at::native::mps

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)