/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkKdTree_h #define itkKdTree_h #include #include #include "itkPoint.h" #include "itkSize.h" #include "itkObject.h" #include "itkArray.h" #include "itkSubsample.h" #include "itkEuclideanDistanceMetric.h" namespace itk { namespace Statistics { /** * \class KdTreeNode * \brief This class defines the interface of its derived classes. * * The methods defined in this class are a superset of the methods * defined in its subclasses. Therefore, the subclasses implements only * part of the methods. The template argument, TSample, can be any * subclass of the Sample class. * * There are two categories for the subclasses, terminal and nonterminal * nodes. The terminal nodes stores the instance identifiers belonging to * them, while the nonterminal nodes don't. Therefore, the * AddInstanceIdentifier and the GetInstanceIdentifier have meaning only * with the terminal ones. The terminal nodes don't have any child (left * or right). For terminal nodes, the GetParameters method is void. * * Recent API changes: * The static const macro to get the length of a measurement vector, * \c MeasurementVectorSize has been removed to allow the length of a measurement * vector to be specified at run time. The \c type alias for \c CentroidType has * been changed from Array to FixedArray. * * \sa KdTreeNonterminalNode, KdTreeWeightedCentroidNonterminalNode, * KdTreeTerminalNode * \ingroup ITKStatistics */ template struct ITK_TEMPLATE_EXPORT KdTreeNode { /** type alias for itself */ using Self = KdTreeNode; /** Measurement type, not the measurement vector type */ using MeasurementType = typename TSample::MeasurementType; /** Centroid type */ using CentroidType = Array; /** Instance identifier type (index value type for the measurement * vector in a sample */ using InstanceIdentifier = typename TSample::InstanceIdentifier; /** Returns true if the node is a terminal node, that is a node that * doesn't have any child. */ virtual bool IsTerminal() const = 0; /** Fills the partitionDimension (the dimension that was chosen to * split the measurement vectors belong to this node to the left and the * right child among k dimensions) and the partitionValue (the * measurement value on the partitionDimension divides the left and the * right child */ virtual void GetParameters(unsigned int &, MeasurementType &) const = 0; /** Returns the pointer to the left child of this node */ virtual Self * Left() = 0; /** Returns the const pointer to the left child of this node */ virtual const Self * Left() const = 0; /** Returns the pointer to the right child of this node */ virtual Self * Right() = 0; /** Returns the const pointer to the right child of this node */ virtual const Self * Right() const = 0; /** * Returns the number of measurement vectors under this node including * its children */ virtual unsigned int Size() const = 0; /** Returns the vector sum of the all measurement vectors under this node */ virtual void GetWeightedCentroid(CentroidType &) = 0; /** Returns the centroid. weighted centroid divided by the size */ virtual void GetCentroid(CentroidType &) = 0; /** Returns the instance identifier of the index-th measurement vector */ virtual InstanceIdentifier GetInstanceIdentifier(InstanceIdentifier) const = 0; /** Add an instance to this node */ virtual void AddInstanceIdentifier(InstanceIdentifier) = 0; /** Destructor */ virtual ~KdTreeNode() = default; // needed to subclasses will actually be deleted }; // end of class /** * \class KdTreeNonterminalNode * \brief This is a subclass of the KdTreeNode. * * KdTreeNonterminalNode doesn't store the information related with the * centroids. Therefore, the GetWeightedCentroid and the GetCentroid * methods are void. This class should have the left and the right * children. If we have a sample and want to generate a KdTree without * the centroid related information, we can use the KdTreeGenerator. * * \sa KdTreeNode, KdTreeWeightedCentroidNonterminalNode, KdTreeGenerator * \ingroup ITKStatistics */ template struct ITK_TEMPLATE_EXPORT KdTreeNonterminalNode : public KdTreeNode { using Superclass = KdTreeNode; using typename Superclass::MeasurementType; using typename Superclass::CentroidType; using typename Superclass::InstanceIdentifier; KdTreeNonterminalNode(unsigned int, MeasurementType, Superclass *, Superclass *); ~KdTreeNonterminalNode() override = default; bool IsTerminal() const override { return false; } void GetParameters(unsigned int &, MeasurementType &) const override; /** Returns the pointer to the left child of this node */ Superclass * Left() override { return m_Left; } /** Returns the pointer to the right child of this node */ Superclass * Right() override { return m_Right; } /** Returns the const pointer to the left child of this node */ const Superclass * Left() const override { return m_Left; } /** Returns the const pointer to the right child of this node */ const Superclass * Right() const override { return m_Right; } /** * Returns the number of measurement vectors under this node including * its children */ unsigned int Size() const override { return 0; } /** * Returns the vector sum of the all measurement vectors under this node. * Do nothing for this class. */ void GetWeightedCentroid(CentroidType &) override {} /** * Returns the centroid. weighted centroid divided by the size. Do nothing for * this class. */ void GetCentroid(CentroidType &) override {} /** * Returns the identifier of the only MeasurementVector associated with * this node in the tree. This MeasurementVector will be used later during * the distance computation when querying the tree. */ InstanceIdentifier GetInstanceIdentifier(InstanceIdentifier) const override { return this->m_InstanceIdentifier; } /** * Set the identifier of the node. */ void AddInstanceIdentifier(InstanceIdentifier valueId) override { this->m_InstanceIdentifier = valueId; } private: unsigned int m_PartitionDimension{}; MeasurementType m_PartitionValue{}; InstanceIdentifier m_InstanceIdentifier{}; Superclass * m_Left{}; Superclass * m_Right{}; }; // end of class /** * \class KdTreeWeightedCentroidNonterminalNode * \brief This is a subclass of the KdTreeNode. * * KdTreeNonterminalNode does have the information related with the * centroids. Therefore, the GetWeightedCentroid and the GetCentroid * methods returns meaningful values. This class should have the left * and right children. If we have a sample and want to generate a KdTree * with the centroid related information, we can use the * WeightedCentroidKdTreeGenerator. The centroid, the weighted * centroid, and the size (the number of measurement vectors) can be * used to accelerate the k-means estimation. * * \sa KdTreeNode, KdTreeNonterminalNode, WeightedCentroidKdTreeGenerator * \ingroup ITKStatistics */ template struct ITK_TEMPLATE_EXPORT KdTreeWeightedCentroidNonterminalNode : public KdTreeNode { using Superclass = KdTreeNode; using typename Superclass::MeasurementType; using typename Superclass::CentroidType; using typename Superclass::InstanceIdentifier; using MeasurementVectorSizeType = typename TSample::MeasurementVectorSizeType; KdTreeWeightedCentroidNonterminalNode(unsigned int, MeasurementType, Superclass *, Superclass *, CentroidType &, unsigned int); ~KdTreeWeightedCentroidNonterminalNode() override = default; /** Not a terminal node. */ bool IsTerminal() const override { return false; } /** Return the parameters of the node. */ void GetParameters(unsigned int &, MeasurementType &) const override; /** Return the length of a measurement vector */ MeasurementVectorSizeType GetMeasurementVectorSize() const { return m_MeasurementVectorSize; } /** Return the left tree pointer. */ Superclass * Left() override { return m_Left; } /** Return the right tree pointer. */ Superclass * Right() override { return m_Right; } /** Return the left tree const pointer. */ const Superclass * Left() const override { return m_Left; } /** Return the right tree const pointer. */ const Superclass * Right() const override { return m_Right; } /** Return the size of the node. */ unsigned int Size() const override { return m_Size; } /** * Returns the vector sum of the all measurement vectors under this node. */ void GetWeightedCentroid(CentroidType & centroid) override { centroid = m_WeightedCentroid; } /** * Returns the centroid. weighted centroid divided by the size. */ void GetCentroid(CentroidType & centroid) override { centroid = m_Centroid; } /** * Returns the identifier of the only MeasurementVector associated with * this node in the tree. This MeasurementVector will be used later during * the distance computation when querying the tree. */ InstanceIdentifier GetInstanceIdentifier(InstanceIdentifier) const override { return this->m_InstanceIdentifier; } /** * Set the identifier of the node. */ void AddInstanceIdentifier(InstanceIdentifier valueId) override { this->m_InstanceIdentifier = valueId; } private: MeasurementVectorSizeType m_MeasurementVectorSize{}; unsigned int m_PartitionDimension{}; MeasurementType m_PartitionValue{}; CentroidType m_WeightedCentroid{}; CentroidType m_Centroid{}; InstanceIdentifier m_InstanceIdentifier{}; unsigned int m_Size{}; Superclass * m_Left{}; Superclass * m_Right{}; }; // end of class /** * \class KdTreeTerminalNode * \brief This class is the node that doesn't have any child node. The * IsTerminal method returns true for this class. This class stores the * instance identifiers belonging to this node, while the nonterminal * nodes do not store them. The AddInstanceIdentifier and * GetInstanceIdentifier are storing and retrieving the instance * identifiers belonging to this node. * * \sa KdTreeNode, KdTreeNonterminalNode, * KdTreeWeightedCentroidNonterminalNode * \ingroup ITKStatistics */ template struct ITK_TEMPLATE_EXPORT KdTreeTerminalNode : public KdTreeNode { using Superclass = KdTreeNode; using typename Superclass::MeasurementType; using typename Superclass::CentroidType; using typename Superclass::InstanceIdentifier; KdTreeTerminalNode() = default; ~KdTreeTerminalNode() override { this->m_InstanceIdentifiers.clear(); } /** A terminal node. */ bool IsTerminal() const override { return true; } /** Return the parameters of the node. */ void GetParameters(unsigned int &, MeasurementType &) const override {} /** Return the left tree pointer. Null for terminal nodes. */ Superclass * Left() override { return nullptr; } /** Return the right tree pointer. Null for terminal nodes. */ Superclass * Right() override { return nullptr; } /** Return the left tree const pointer. Null for terminal nodes. */ const Superclass * Left() const override { return nullptr; } /** Return the right tree const pointer. Null for terminal nodes. */ const Superclass * Right() const override { return nullptr; } /** Return the size of the node. */ unsigned int Size() const override { return static_cast(m_InstanceIdentifiers.size()); } /** * Returns the vector sum of the all measurement vectors under this node. * Do nothing for this case. */ void GetWeightedCentroid(CentroidType &) override {} /** * Returns the centroid. weighted centroid divided by the size. Do nothing * for this case. */ void GetCentroid(CentroidType &) override {} /** * Returns the identifier of the only MeasurementVector associated with * this node in the tree. This MeasurementVector will be used later during * the distance computation when querying the tree. */ InstanceIdentifier GetInstanceIdentifier(InstanceIdentifier index) const override { return m_InstanceIdentifiers[index]; } /** * Set the identifier of the node. */ void AddInstanceIdentifier(InstanceIdentifier id) override { m_InstanceIdentifiers.push_back(id); } private: std::vector m_InstanceIdentifiers{}; }; // end of class /** * \class KdTree * \brief This class provides methods for k-nearest neighbor search and * related data structures for a k-d tree. * * An object of this class stores instance identifiers in a k-d tree * that is a binary tree with children split along a dimension among * k-dimensions. The dimension of the split (or partition) is determined * for each nonterminal node that has two children. The split process is * terminated when the node has no children (when the number of * measurement vectors is less than or equal to the size set by the * SetBucketSize. That is The split process is a recursive process in * nature and in implementation. This implementation doesn't support * dynamic insert and delete operations for the tree. Instead, we can * use the KdTreeGenerator or WeightedCentroidKdTreeGenerator to * generate a static KdTree object. * * To search k-nearest neighbor, call the Search method with the query * point in a k-d space and the number of nearest neighbors. The * GetSearchResult method returns a pointer to a NearestNeighbors object * with k-nearest neighbors. * * Recent API changes: * The static const macro to get the length of a measurement vector, * 'MeasurementVectorSize' has been removed to allow the length of a measurement * vector to be specified at run time. Please use the function * GetMeasurementVectorSize() instead. * \sa KdTreeNode, KdTreeNonterminalNode, * KdTreeWeightedCentroidNonterminalNode, KdTreeTerminalNode, * KdTreeGenerator, WeightedCentroidKdTreeNode * \ingroup ITKStatistics */ template class ITK_TEMPLATE_EXPORT KdTree : public Object { public: ITK_DISALLOW_COPY_AND_MOVE(KdTree); /** Standard class type aliases */ using Self = KdTree; using Superclass = Object; using Pointer = SmartPointer; using ConstPointer = SmartPointer; /** \see LightObject::GetNameOfClass() */ itkOverrideGetNameOfClassMacro(KdTree); /** Method for creation through the object factory. */ itkNewMacro(Self); /** type alias alias for the source data container */ using SampleType = TSample; using MeasurementVectorType = typename TSample::MeasurementVectorType; using MeasurementType = typename TSample::MeasurementType; using InstanceIdentifier = typename TSample::InstanceIdentifier; using AbsoluteFrequencyType = typename TSample::AbsoluteFrequencyType; using MeasurementVectorSizeType = unsigned int; /** Get Macro to get the length of a measurement vector in the KdTree. * The length is obtained from the input sample. */ itkGetConstMacro(MeasurementVectorSize, MeasurementVectorSizeType); /** DistanceMetric type for the distance calculation and comparison */ using DistanceMetricType = EuclideanDistanceMetric; /** Node type of the KdTree */ using KdTreeNodeType = KdTreeNode; /** Neighbor type. The first element of the std::pair is the instance * identifier and the second one is the distance between the measurement * vector identified by the first element and the query point. */ using NeighborType = std::pair; using InstanceIdentifierVectorType = std::vector; /** * \class NearestNeighbors * \brief data structure for storing k-nearest neighbor search result * (k number of Neighbors) * * This class stores the instance identifiers and the distance values * of k-nearest neighbors. We can also query the farthest neighbor's * distance from the query point using the GetLargestDistance * method. * \ingroup ITKStatistics */ class NearestNeighbors { public: /** Constructor */ NearestNeighbors(std::vector & cache_vector) : m_FarthestNeighborIndex(0) , m_Distances(cache_vector) {} NearestNeighbors() = delete; /** Destructor */ ~NearestNeighbors() = default; /** Initialize the internal instance identifier and distance holders * with the size, k */ void resize(unsigned int k) { m_Identifiers.clear(); m_Identifiers.resize(k, NumericTraits::max()); m_Distances.clear(); m_Distances.resize(k, NumericTraits::max()); m_FarthestNeighborIndex = 0; } /** Returns the distance of the farthest neighbor from the query point */ double GetLargestDistance() { return m_Distances[m_FarthestNeighborIndex]; } /** Replaces the farthest neighbor's instance identifier and * distance value with the id and the distance */ void ReplaceFarthestNeighbor(InstanceIdentifier id, double distance) { m_Identifiers[m_FarthestNeighborIndex] = id; m_Distances[m_FarthestNeighborIndex] = distance; double farthestDistance = NumericTraits::min(); const auto size = static_cast(m_Distances.size()); for (unsigned int i = 0; i < size; ++i) { if (m_Distances[i] > farthestDistance) { farthestDistance = m_Distances[i]; m_FarthestNeighborIndex = i; } } } /** Returns the vector of k-neighbors' instance identifiers */ const InstanceIdentifierVectorType & GetNeighbors() const { return m_Identifiers; } /** Returns the instance identifier of the index-th neighbor among * k-neighbors */ InstanceIdentifier GetNeighbor(unsigned int index) const { return m_Identifiers[index]; } private: /** The index of the farthest neighbor among k-neighbors */ unsigned int m_FarthestNeighborIndex; /** Storage for the instance identifiers of k-neighbors */ InstanceIdentifierVectorType m_Identifiers; /** External storage for the distance values of k-neighbors * from the query point. This is a reference to external * vector to avoid unnecessary memory copying. */ std::vector & m_Distances; }; /** Sets the number of measurement vectors that can be stored in a * terminal node */ void SetBucketSize(unsigned int); /** Sets the input sample that provides the measurement vectors to the k-d * tree */ void SetSample(const TSample *); /** Returns the pointer to the input sample */ const TSample * GetSample() const { return m_Sample; } SizeValueType Size() const { return m_Sample->Size(); } /** Returns the pointer to the empty terminal node. A KdTree object * has a single empty terminal node in memory. when the split process * has to create an empty terminal node, the single instance is reused * for this case */ KdTreeNodeType * GetEmptyTerminalNode() { return m_EmptyTerminalNode; } /** Sets the root node of the KdTree that is a result of * KdTreeGenerator or WeightedCentroidKdTreeGenerator. */ void SetRoot(KdTreeNodeType * root) { if (this->m_Root) { this->DeleteNode(this->m_Root); } this->m_Root = root; } /** Returns the pointer to the root node. */ KdTreeNodeType * GetRoot() { return m_Root; } /** Returns the measurement vector identified by the instance * identifier that is an identifier defined for the input sample */ const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const { return m_Sample->GetMeasurementVector(id); } /** Returns the frequency of the measurement vector identified by * the instance identifier */ AbsoluteFrequencyType GetFrequency(InstanceIdentifier id) const { return m_Sample->GetFrequency(id); } /** Get the pointer to the distance metric. */ DistanceMetricType * GetDistanceMetric() { return m_DistanceMetric.GetPointer(); } /** Searches the k-nearest neighbors */ void Search(const MeasurementVectorType &, unsigned int, InstanceIdentifierVectorType &) const; /** Searches the k-nearest neighbors and returns * the distance vector along with the distance measures. */ void Search(const MeasurementVectorType &, unsigned int, InstanceIdentifierVectorType &, std::vector &) const; /** Searches the neighbors fallen into a hypersphere */ void Search(const MeasurementVectorType &, double, InstanceIdentifierVectorType &) const; /** Returns true if the intermediate k-nearest neighbors exist within * the bounding box defined by the lowerBound and the * upperBound. Otherwise returns false. Returns false if the ball * defined by the distance between the query point and the farthest * neighbor touch the surface of the bounding box. */ bool BallWithinBounds(const MeasurementVectorType &, MeasurementVectorType &, MeasurementVectorType &, double) const; /** Returns true if the ball defined by the distance between the query * point and the farthest neighbor overlaps with the bounding box * defined by the lower and the upper bounds. */ bool BoundsOverlapBall(const MeasurementVectorType &, MeasurementVectorType &, MeasurementVectorType &, double) const; /** Deletes the node recursively */ void DeleteNode(KdTreeNodeType *); /** Prints out the tree information */ void PrintTree(std::ostream &) const; /** Prints out the tree information */ void PrintTree(KdTreeNodeType *, unsigned int, unsigned int, std::ostream & os = std::cout) const; /** Draw out the tree information to a ostream using * the format of the Graphviz dot tool. */ void PlotTree(std::ostream & os) const; /** Prints out the tree information */ void PlotTree(KdTreeNodeType * node, std::ostream & os = std::cout) const; using Iterator = typename TSample::Iterator; using ConstIterator = typename TSample::ConstIterator; protected: /** Constructor */ KdTree(); /** Destructor: deletes the root node and the empty terminal node. */ ~KdTree() override; void PrintSelf(std::ostream & os, Indent indent) const override; /** search loop */ int NearestNeighborSearchLoop(const KdTreeNodeType *, const MeasurementVectorType &, MeasurementVectorType &, MeasurementVectorType &, NearestNeighbors &) const; /** search loop */ int SearchLoop(const KdTreeNodeType *, const MeasurementVectorType &, double, MeasurementVectorType &, MeasurementVectorType &, InstanceIdentifierVectorType &) const; private: /** Pointer to the input sample */ const TSample * m_Sample{}; /** Number of measurement vectors can be stored in a terminal node. */ int m_BucketSize{}; /** Pointer to the root node */ KdTreeNodeType * m_Root{}; /** Pointer to the empty terminal node */ KdTreeNodeType * m_EmptyTerminalNode{}; /** Distance metric smart pointer */ typename DistanceMetricType::Pointer m_DistanceMetric{}; /** Measurement vector size */ MeasurementVectorSizeType m_MeasurementVectorSize{}; }; // end of class } // end of namespace Statistics } // end of namespace itk #ifndef ITK_MANUAL_INSTANTIATION # include "itkKdTree.hxx" #endif #endif