/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkKdTreeGenerator_hxx #define itkKdTreeGenerator_hxx namespace itk { namespace Statistics { template KdTreeGenerator::KdTreeGenerator() { m_SourceSample = nullptr; m_BucketSize = 16; m_Subsample = SubsampleType::New(); m_MeasurementVectorSize = 0; } template void KdTreeGenerator::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); os << indent << "Source Sample: "; if (m_SourceSample != nullptr) { os << m_SourceSample << std::endl; } else { os << "not set." << std::endl; } os << indent << "Bucket Size: " << m_BucketSize << std::endl; os << indent << "MeasurementVectorSize: " << m_MeasurementVectorSize << std::endl; } template void KdTreeGenerator::SetSample(TSample * sample) { m_SourceSample = sample; m_Subsample->SetSample(sample); m_Subsample->InitializeWithAllInstances(); m_MeasurementVectorSize = sample->GetMeasurementVectorSize(); NumericTraits::SetLength(m_TempLowerBound, m_MeasurementVectorSize); NumericTraits::SetLength(m_TempUpperBound, m_MeasurementVectorSize); NumericTraits::SetLength(m_TempMean, m_MeasurementVectorSize); } template void KdTreeGenerator::SetBucketSize(unsigned int size) { m_BucketSize = size; } template void KdTreeGenerator::GenerateData() { if (m_SourceSample == nullptr) { return; } if (m_Tree.IsNull()) { m_Tree = KdTreeType::New(); m_Tree->SetSample(m_SourceSample); m_Tree->SetBucketSize(m_BucketSize); } SubsamplePointer subsample = this->GetSubsample(); // Sanity check. Verify that the subsample has measurement vectors of the // same length as the sample generated by the tree. if (this->GetMeasurementVectorSize() != subsample->GetMeasurementVectorSize()) { itkExceptionMacro("Measurement Vector Length mismatch"); } MeasurementVectorType lowerBound; NumericTraits::SetLength(lowerBound, m_MeasurementVectorSize); MeasurementVectorType upperBound; NumericTraits::SetLength(upperBound, m_MeasurementVectorSize); for (unsigned int d = 0; d < m_MeasurementVectorSize; ++d) { lowerBound[d] = NumericTraits::NonpositiveMin(); upperBound[d] = NumericTraits::max(); } KdTreeNodeType * root = this->GenerateTreeLoop(0, m_Subsample->Size(), lowerBound, upperBound, 0); m_Tree->SetRoot(root); } template inline auto KdTreeGenerator::GenerateNonterminalNode(unsigned int beginIndex, unsigned int endIndex, MeasurementVectorType & lowerBound, MeasurementVectorType & upperBound, unsigned int level) -> KdTreeNodeType * { using NodeType = typename KdTreeType::KdTreeNodeType; MeasurementType dimensionLowerBound; MeasurementType dimensionUpperBound; MeasurementType partitionValue; unsigned int partitionDimension = 0; unsigned int i; MeasurementType spread; MeasurementType maxSpread; unsigned int medianIndex; SubsamplePointer subsample = this->GetSubsample(); // find most widely spread dimension Algorithm::FindSampleBoundAndMean( subsample, beginIndex, endIndex, m_TempLowerBound, m_TempUpperBound, m_TempMean); maxSpread = NumericTraits::NonpositiveMin(); for (i = 0; i < m_MeasurementVectorSize; ++i) { spread = m_TempUpperBound[i] - m_TempLowerBound[i]; if (spread >= maxSpread) { maxSpread = spread; partitionDimension = i; } } medianIndex = (endIndex - beginIndex) / 2; // // Find the medial element by using the NthElement function // based on the STL implementation of the QuickSelect algorithm. // partitionValue = Algorithm::NthElement(m_Subsample, partitionDimension, beginIndex, endIndex, medianIndex); medianIndex += beginIndex; // save bounds for cutting dimension dimensionLowerBound = lowerBound[partitionDimension]; dimensionUpperBound = upperBound[partitionDimension]; upperBound[partitionDimension] = partitionValue; const unsigned int beginLeftIndex = beginIndex; const unsigned int endLeftIndex = medianIndex; NodeType * left = GenerateTreeLoop(beginLeftIndex, endLeftIndex, lowerBound, upperBound, level + 1); upperBound[partitionDimension] = dimensionUpperBound; lowerBound[partitionDimension] = partitionValue; const unsigned int beginRightIndex = medianIndex + 1; const unsigned int endRightIndex = endIndex; NodeType * right = GenerateTreeLoop(beginRightIndex, endRightIndex, lowerBound, upperBound, level + 1); lowerBound[partitionDimension] = dimensionLowerBound; using KdTreeNonterminalNodeType = KdTreeNonterminalNode; auto * nonTerminalNode = new KdTreeNonterminalNodeType(partitionDimension, partitionValue, left, right); nonTerminalNode->AddInstanceIdentifier(subsample->GetInstanceIdentifier(medianIndex)); return nonTerminalNode; } template inline auto KdTreeGenerator::GenerateTreeLoop(unsigned int beginIndex, unsigned int endIndex, MeasurementVectorType & lowerBound, MeasurementVectorType & upperBound, unsigned int level) -> KdTreeNodeType * { if (endIndex - beginIndex <= m_BucketSize) { // numberOfInstances small, make a terminal node if (endIndex == beginIndex) { // return the pointer to empty terminal node return m_Tree->GetEmptyTerminalNode(); } else { auto * ptr = new KdTreeTerminalNode(); for (unsigned int j = beginIndex; j < endIndex; ++j) { ptr->AddInstanceIdentifier(this->GetSubsample()->GetInstanceIdentifier(j)); } // return a terminal node return ptr; } } else { return this->GenerateNonterminalNode(beginIndex, endIndex, lowerBound, upperBound, level + 1); } } } // end of namespace Statistics } // end of namespace itk #endif