/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #include "itkMersenneTwisterRandomVariateGenerator.h" #include "itkListSample.h" #include "itkKdTreeGenerator.h" #include "itkTestingMacros.h" #include #include int itkKdTreeTest3(int argc, char * argv[]) { if (argc < 5) { std::cerr << "Missing parameters." << std::endl; std::cerr << "Usage: " << itkNameOfTestExecutableMacro(argv); std::cerr << " numberOfDataPoints numberOfTestPoints " << "numberOfNeighbors bucketSize [graphvizDotOutputFile]" << std::endl; return EXIT_FAILURE; } // Random number generator using NumberGeneratorType = itk::Statistics::MersenneTwisterRandomVariateGenerator; NumberGeneratorType::Pointer randomNumberGenerator = NumberGeneratorType::GetInstance(); randomNumberGenerator->Initialize(); using MeasurementVectorType = itk::Array; using SampleType = itk::Statistics::ListSample; constexpr SampleType::MeasurementVectorSizeType measurementVectorSize = 2; auto sample = SampleType::New(); sample->SetMeasurementVectorSize(measurementVectorSize); // // Generate a sample of random points // const unsigned int numberOfDataPoints = std::stoi(argv[1]); MeasurementVectorType mv(measurementVectorSize); for (unsigned int i = 0; i < numberOfDataPoints; ++i) { mv[0] = randomNumberGenerator->GetNormalVariate(0.0, 1.0); mv[1] = randomNumberGenerator->GetNormalVariate(0.0, 1.0); sample->PushBack(mv); } using TreeGeneratorType = itk::Statistics::KdTreeGenerator; auto treeGenerator = TreeGeneratorType::New(); const unsigned int bucketSize = std::stoi(argv[4]); treeGenerator->SetSample(sample); treeGenerator->SetBucketSize(bucketSize); treeGenerator->Update(); using TreeType = TreeGeneratorType::KdTreeType; TreeType::Pointer tree = treeGenerator->GetOutput(); MeasurementVectorType queryPoint(measurementVectorSize); unsigned int numberOfNeighbors = std::stoi(argv[3]); if (numberOfNeighbors > numberOfDataPoints) { numberOfNeighbors = numberOfDataPoints; } TreeType::InstanceIdentifierVectorType neighbors1; TreeType::InstanceIdentifierVectorType neighbors2; MeasurementVectorType result(measurementVectorSize); MeasurementVectorType test_point(measurementVectorSize); MeasurementVectorType min_point(measurementVectorSize); unsigned int numberOfFailedPoints1 = 0; const unsigned int numberOfTestPoints = std::stoi(argv[2]); // // Check that for every point in the sample, its closest point is itself. // using DistanceMetricType = itk::Statistics::EuclideanDistanceMetric; using OriginType = DistanceMetricType::OriginType; auto distanceMetric = DistanceMetricType::New(); OriginType origin(measurementVectorSize); for (unsigned int k = 0; k < sample->Size(); ++k) { queryPoint = sample->GetMeasurementVector(k); for (unsigned int i = 0; i < sample->GetMeasurementVectorSize(); ++i) { origin[i] = queryPoint[i]; } distanceMetric->SetOrigin(origin); tree->Search(queryPoint, numberOfNeighbors, neighbors1); double max_distance = 0.; for (const auto i : neighbors1) { const double distance = distanceMetric->Evaluate(tree->GetMeasurementVector(i)); max_distance = std::max(distance, max_distance); } max_distance += itk::NumericTraits::epsilon() * 10.0; tree->Search(queryPoint, max_distance, neighbors2); for (size_t i = 0; i < neighbors2.size(); ++i) { auto temp_it = std::find(neighbors1.begin(), neighbors1.end(), neighbors2[i]); if (temp_it == neighbors1.end()) { std::cerr << "neighbors2[" << i << "] = " << neighbors2[i] << " is not in neighbors1" << std::endl; numberOfFailedPoints1++; } } for (size_t i = 0; i < neighbors1.size(); ++i) { auto temp_it = std::find(neighbors2.begin(), neighbors2.end(), neighbors1[i]); if (temp_it == neighbors2.end()) { std::cerr << "neighbors1[" << i << "] = " << neighbors1[i] << " is not in neighbors2" << std::endl; numberOfFailedPoints1++; } } } // // Generate a second sample of random points // and use them to query the tree // unsigned int numberOfFailedPoints2 = 0; for (unsigned int j = 0; j < numberOfTestPoints; ++j) { double max_distance = 0.; queryPoint[0] = randomNumberGenerator->GetNormalVariate(0.0, 1.0); queryPoint[1] = randomNumberGenerator->GetNormalVariate(0.0, 1.0); for (unsigned int i = 0; i < sample->GetMeasurementVectorSize(); ++i) { origin[i] = queryPoint[i]; } distanceMetric->SetOrigin(origin); tree->Search(queryPoint, numberOfNeighbors, neighbors1); for (const auto i : neighbors1) { const double distance = distanceMetric->Evaluate(tree->GetMeasurementVector(i)); max_distance = std::max(distance, max_distance); } max_distance += itk::NumericTraits::epsilon() * 10.0; tree->Search(queryPoint, max_distance, neighbors2); for (size_t i = 0; i < neighbors2.size(); ++i) { const double distance = distanceMetric->Evaluate(tree->GetMeasurementVector(neighbors2[i])); if (distance <= max_distance) { auto temp_it = std::find(neighbors1.begin(), neighbors1.end(), neighbors2[i]); if (temp_it == neighbors1.end()) { std::cerr << "neighbors2[" << i << "] = " << neighbors2[i] << " is not in neighbors1" << std::endl; numberOfFailedPoints2++; } } } } if (argc > 5) { // // Plot out the tree structure to the console in the format used by Graphviz dot // std::ofstream plotFile; plotFile.open(argv[5]); tree->PlotTree(plotFile); plotFile.close(); } if (numberOfFailedPoints1) { std::cerr << numberOfFailedPoints1 << " out of " << sample->Size(); std::cerr << " points failed to find themselves as closest-point" << std::endl; } if (numberOfFailedPoints2) { std::cerr << numberOfFailedPoints2 << " out of " << numberOfTestPoints; std::cerr << " points failed to find the correct closest point." << std::endl; } if (numberOfFailedPoints1 || numberOfFailedPoints2) { return EXIT_FAILURE; } std::cout << "Test PASSED." << std::endl; return EXIT_SUCCESS; }