/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkKdTreeBasedKmeansEstimator_hxx #define itkKdTreeBasedKmeansEstimator_hxx #include "itkStatisticsAlgorithm.h" namespace itk { namespace Statistics { template KdTreeBasedKmeansEstimator::KdTreeBasedKmeansEstimator() : m_KdTree(nullptr) , m_DistanceMetric(EuclideanDistanceMetric::New()) , m_MembershipFunctionsObject(MembershipFunctionVectorObjectType::New()) { m_TempVertex.Fill(0.0); } template void KdTreeBasedKmeansEstimator::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); os << indent << "Current Iteration: " << this->GetCurrentIteration() << std::endl; os << indent << "Maximum Iteration: " << this->GetMaximumIteration() << std::endl; os << indent << "Sum of Centroid Position Changes: " << this->GetCentroidPositionChanges() << std::endl; os << indent << "Threshold for the Sum of Centroid Position Changes: " << this->GetCentroidPositionChangesThreshold() << std::endl; os << indent << "Kd Tree:"; if (m_KdTree.IsNotNull()) { os << this->GetKdTree() << std::endl; } else { os << "not set." << std::endl; } os << indent << "Parameters: " << this->GetParameters() << std::endl; os << indent << "MeasurementVectorSize: " << this->GetMeasurementVectorSize() << std::endl; os << indent << "UseClusterLabels: " << this->GetUseClusterLabels() << std::endl; } template double KdTreeBasedKmeansEstimator::GetSumOfSquaredPositionChanges(InternalParametersType & previous, InternalParametersType & current) { double temp; double sum = 0.0; unsigned int i; for (i = 0; i < static_cast(previous.size()); ++i) { temp = m_DistanceMetric->Evaluate(previous[i], current[i]); sum += temp; } return sum; } template inline int KdTreeBasedKmeansEstimator::GetClosestCandidate(ParameterType & measurements, std::vector & validIndexes) { int closest = 0; double closestDistance = NumericTraits::max(); double tempDistance; auto iter = validIndexes.begin(); while (iter != validIndexes.end()) { tempDistance = m_DistanceMetric->Evaluate(m_CandidateVector[*iter].Centroid, measurements); if (tempDistance < closestDistance) { closest = *iter; closestDistance = tempDistance; } ++iter; } return closest; } template inline bool KdTreeBasedKmeansEstimator::IsFarther(ParameterType & pointA, ParameterType & pointB, MeasurementVectorType & lowerBound, MeasurementVectorType & upperBound) { // calculates the vertex of the Cell bounded by the lowerBound // and the upperBound for (unsigned int i = 0; i < m_MeasurementVectorSize; ++i) { if ((pointA[i] - pointB[i]) < 0.0) { m_TempVertex[i] = lowerBound[i]; } else { m_TempVertex[i] = upperBound[i]; } } if (m_DistanceMetric->Evaluate(pointA, m_TempVertex) >= m_DistanceMetric->Evaluate(pointB, m_TempVertex)) { return true; } return false; } template inline void KdTreeBasedKmeansEstimator::Filter(KdTreeNodeType * node, std::vector validIndexes, MeasurementVectorType & lowerBound, MeasurementVectorType & upperBound) { unsigned int i, j; typename TKdTree::InstanceIdentifier tempId; int closest; ParameterType individualPoint; NumericTraits::SetLength(individualPoint, this->m_MeasurementVectorSize); if (node->IsTerminal()) { // terminal node if (node == m_KdTree->GetEmptyTerminalNode()) { // empty node return; } for (i = 0; i < static_cast(node->Size()); ++i) { tempId = node->GetInstanceIdentifier(i); this->GetPoint(individualPoint, m_KdTree->GetMeasurementVector(tempId)); closest = this->GetClosestCandidate(individualPoint, validIndexes); for (j = 0; j < m_MeasurementVectorSize; ++j) { m_CandidateVector[closest].WeightedCentroid[j] += individualPoint[j]; } m_CandidateVector[closest].Size += 1; if (m_GenerateClusterLabels) { m_ClusterLabels[tempId] = closest; } } } else { CentroidType centroid; CentroidType weightedCentroid; ParameterType closestPosition; node->GetWeightedCentroid(weightedCentroid); node->GetCentroid(centroid); closest = this->GetClosestCandidate(centroid, validIndexes); closestPosition = m_CandidateVector[closest].Centroid; auto iter = validIndexes.begin(); while (iter != validIndexes.end()) { if (*iter != closest && this->IsFarther(m_CandidateVector[*iter].Centroid, closestPosition, lowerBound, upperBound)) { iter = validIndexes.erase(iter); continue; } if (iter != validIndexes.end()) { ++iter; } } if (validIndexes.size() == 1) { for (j = 0; j < m_MeasurementVectorSize; ++j) { m_CandidateVector[closest].WeightedCentroid[j] += weightedCentroid[j]; } m_CandidateVector[closest].Size += node->Size(); if (m_GenerateClusterLabels) { this->FillClusterLabels(node, closest); } } else { unsigned int partitionDimension; MeasurementType partitionValue; MeasurementType tempValue; node->GetParameters(partitionDimension, partitionValue); tempValue = upperBound[partitionDimension]; upperBound[partitionDimension] = partitionValue; this->Filter(node->Left(), validIndexes, lowerBound, upperBound); upperBound[partitionDimension] = tempValue; tempValue = lowerBound[partitionDimension]; lowerBound[partitionDimension] = partitionValue; this->Filter(node->Right(), validIndexes, lowerBound, upperBound); lowerBound[partitionDimension] = tempValue; } } } template void KdTreeBasedKmeansEstimator::FillClusterLabels(KdTreeNodeType * node, int closestIndex) { unsigned int i; if (node->IsTerminal()) { // terminal node if (node == m_KdTree->GetEmptyTerminalNode()) { // empty node return; } for (i = 0; i < static_cast(node->Size()); ++i) { m_ClusterLabels[node->GetInstanceIdentifier(i)] = closestIndex; } } else { this->FillClusterLabels(node->Left(), closestIndex); this->FillClusterLabels(node->Right(), closestIndex); } } template void KdTreeBasedKmeansEstimator::CopyParameters(ParametersType & source, InternalParametersType & target) { unsigned int i, j; int index = 0; for (i = 0; i < static_cast(source.size() / m_MeasurementVectorSize); ++i) { for (j = 0; j < m_MeasurementVectorSize; ++j) { target[i][j] = source[index]; ++index; } } } template void KdTreeBasedKmeansEstimator::CopyParameters(InternalParametersType & source, ParametersType & target) { unsigned int i, j; int index = 0; for (i = 0; i < static_cast(source.size()); ++i) { for (j = 0; j < m_MeasurementVectorSize; ++j) { target[index] = source[i][j]; ++index; } } } template void KdTreeBasedKmeansEstimator::CopyParameters(InternalParametersType & source, InternalParametersType & target) { unsigned int i, j; for (i = 0; i < static_cast(source.size()); ++i) { for (j = 0; j < m_MeasurementVectorSize; ++j) { target[i][j] = source[i][j]; } } } template void KdTreeBasedKmeansEstimator::StartOptimization() { unsigned int i; MeasurementVectorType lowerBound; MeasurementVectorType upperBound; NumericTraits::SetLength(lowerBound, this->m_MeasurementVectorSize); NumericTraits::SetLength(upperBound, this->m_MeasurementVectorSize); Algorithm::FindSampleBound( m_KdTree->GetSample(), m_KdTree->GetSample()->Begin(), m_KdTree->GetSample()->End(), lowerBound, upperBound); InternalParametersType previousPosition; // previousPosition.resize(m_Parameters.size() / m_MeasurementVectorSize); InternalParametersType currentPosition; // currentPosition.resize(m_Parameters.size() / m_MeasurementVectorSize); for (i = 0; i < m_Parameters.size() / m_MeasurementVectorSize; ++i) { ParameterType m; ParameterType m1; NumericTraits::SetLength(m, m_MeasurementVectorSize); NumericTraits::SetLength(m1, m_MeasurementVectorSize); previousPosition.push_back(m); currentPosition.push_back(m1); } this->CopyParameters(m_Parameters, currentPosition); m_CurrentIteration = 0; std::vector validIndexes; for (i = 0; i < static_cast(m_Parameters.size() / m_MeasurementVectorSize); ++i) { validIndexes.push_back(i); } m_GenerateClusterLabels = false; while (true) { this->CopyParameters(currentPosition, previousPosition); m_CandidateVector.SetCentroids(currentPosition); this->Filter(m_KdTree->GetRoot(), validIndexes, lowerBound, upperBound); m_CandidateVector.UpdateCentroids(); m_CandidateVector.GetCentroids(currentPosition); if (m_CurrentIteration >= m_MaximumIteration) { break; } m_CentroidPositionChanges = this->GetSumOfSquaredPositionChanges(previousPosition, currentPosition); if (m_CentroidPositionChanges <= m_CentroidPositionChangesThreshold) { break; } ++m_CurrentIteration; } if (m_UseClusterLabels) { m_GenerateClusterLabels = true; m_ClusterLabels.clear(); m_ClusterLabels.rehash(m_KdTree->GetSample()->Size()); for (i = 0; i < static_cast(m_Parameters.size() / m_MeasurementVectorSize); ++i) { validIndexes.push_back(i); } this->Filter(m_KdTree->GetRoot(), validIndexes, lowerBound, upperBound); } this->CopyParameters(currentPosition, m_Parameters); } template void KdTreeBasedKmeansEstimator::SetKdTree(TKdTree * tree) { m_KdTree = tree; m_MeasurementVectorSize = tree->GetMeasurementVectorSize(); m_DistanceMetric->SetMeasurementVectorSize(m_MeasurementVectorSize); NumericTraits::SetLength(m_TempVertex, m_MeasurementVectorSize); this->Modified(); } template const TKdTree * KdTreeBasedKmeansEstimator::GetKdTree() const { return m_KdTree.GetPointer(); } template auto KdTreeBasedKmeansEstimator::GetOutput() const -> const MembershipFunctionVectorObjectType * { // INSERT CHECKS if all the required inputs are set and optimization has been // run. unsigned int numberOfClasses = m_Parameters.size() / m_MeasurementVectorSize; MembershipFunctionVectorType & membershipFunctionsVector = m_MembershipFunctionsObject->Get(); for (unsigned int i = 0; i < numberOfClasses; ++i) { DistanceToCentroidMembershipFunctionPointer membershipFunction = DistanceToCentroidMembershipFunctionType::New(); membershipFunction->SetMeasurementVectorSize(m_MeasurementVectorSize); typename DistanceToCentroidMembershipFunctionType::CentroidType centroid; centroid.SetSize(m_MeasurementVectorSize); for (unsigned int j = 0; j < m_MeasurementVectorSize; ++j) { unsigned int parameterIndex = i * m_MeasurementVectorSize + j; centroid[j] = m_Parameters[parameterIndex]; } membershipFunction->SetCentroid(centroid); membershipFunctionsVector.push_back(membershipFunction); } return static_cast(m_MembershipFunctionsObject); } template void KdTreeBasedKmeansEstimator::GetPoint(ParameterType & point, MeasurementVectorType measurements) { for (unsigned int i = 0; i < m_MeasurementVectorSize; ++i) { point[i] = measurements[i]; } } template void KdTreeBasedKmeansEstimator::PrintPoint(ParameterType & point) { std::cout << "[ "; for (unsigned int i = 0; i < m_MeasurementVectorSize; ++i) { std::cout << point[i] << ' '; } std::cout << ']'; } } // end of namespace Statistics } // end namespace itk #endif