/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #include "itkMeanSquaresImageToImageMetricv4.h" #include "itkMattesMutualInformationImageToImageMetricv4.h" #include "itkJointHistogramMutualInformationImageToImageMetricv4.h" #include "itkANTSNeighborhoodCorrelationImageToImageMetricv4.h" #include "itkCorrelationImageToImageMetricv4.h" #include "itkTranslationTransform.h" #include "itkLinearInterpolateImageFunction.h" #include "itkImage.h" #include "itkGaussianImageSource.h" #include "itkCyclicShiftImageFilter.h" #include "itkRegistrationParameterScalesFromPhysicalShift.h" #include "itkGradientDescentOptimizerv4.h" #include "itkImageRegionIteratorWithIndex.h" /* This test performs a simple registration test on each * ImageToImageMetricv4 metric, testing that: * 1) metric value is minimized * 2) final optimization position is correct within tolerance * 3) different options for sampling and image gradient calculation work * New metrics must be added manually to this test. */ template int ImageToImageMetricv4RegistrationTestRun(typename TMetric::Pointer metric, int numberOfIterations, typename TImage::PixelType maximumStepSize, bool doSampling, bool doGradientFilter) { using PixelType = typename TImage::PixelType; using CoordinateRepresentationType = PixelType; // Create two simple images itk::SizeValueType ImageSize = 100; itk::OffsetValueType boundary = 6; if (Dimension == 3) { ImageSize = 60; boundary = 4; } // Declare Gaussian Sources using GaussianImageSourceType = itk::GaussianImageSource; typename TImage::SizeType size; size.Fill(ImageSize); typename TImage::SpacingType spacing; spacing.Fill(itk::NumericTraits::OneValue()); typename TImage::PointType origin; origin.Fill(CoordinateRepresentationType{}); typename TImage::DirectionType direction; direction.Fill(itk::NumericTraits::OneValue()); auto fixedImageSource = GaussianImageSourceType::New(); fixedImageSource->SetSize(size); fixedImageSource->SetOrigin(origin); fixedImageSource->SetSpacing(spacing); fixedImageSource->SetNormalized(false); fixedImageSource->SetScale(1.0f); fixedImageSource->Update(); typename TImage::Pointer fixedImage = fixedImageSource->GetOutput(); // zero-out the boundary itk::ImageRegionIteratorWithIndex it(fixedImage, fixedImage->GetLargestPossibleRegion()); for (it.GoToBegin(); !it.IsAtEnd(); ++it) { for (itk::SizeValueType n = 0; n < Dimension; ++n) { if (it.GetIndex()[n] < boundary || (static_cast(size[n]) - it.GetIndex()[n]) <= boundary) { it.Set(PixelType{}); break; } } } // shift the fixed image to get the moving image using CyclicShiftFilterType = itk::CyclicShiftImageFilter; auto shiftFilter = CyclicShiftFilterType::New(); typename CyclicShiftFilterType::OffsetType imageShift; typename CyclicShiftFilterType::OffsetValueType maxImageShift = boundary - 1; imageShift.Fill(maxImageShift); imageShift[0] = maxImageShift / 2; shiftFilter->SetInput(fixedImage); shiftFilter->SetShift(imageShift); shiftFilter->Update(); typename TImage::Pointer movingImage = shiftFilter->GetOutput(); // create an affine transform using TranslationTransformType = itk::TranslationTransform; auto translationTransform = TranslationTransformType::New(); translationTransform->SetIdentity(); // setup metric // metric->SetFixedImage(fixedImage); metric->SetMovingImage(movingImage); metric->SetMovingTransform(translationTransform); metric->SetUseMovingImageGradientFilter(doGradientFilter); metric->SetUseFixedImageGradientFilter(doGradientFilter); std::cout << "Use image gradient filter: " << doGradientFilter << std::endl; // sampling if (!doSampling) { std::cout << "Dense sampling." << std::endl; metric->SetUseSampledPointSet(false); } else { using PointSetType = typename TMetric::FixedSampledPointSetType; using PointType = typename PointSetType::PointType; typename PointSetType::Pointer pset(PointSetType::New()); itk::SizeValueType ind = 0, ct = 0; itk::ImageRegionIteratorWithIndex itS(fixedImage, fixedImage->GetLargestPossibleRegion()); for (itS.GoToBegin(); !itS.IsAtEnd(); ++itS) { // take every N^th point // not sampling sparsely in order to get all metrics to pass // with similar settings if (ct % 2 == 0) { PointType pt; fixedImage->TransformIndexToPhysicalPoint(itS.GetIndex(), pt); pset->SetPoint(ind, pt); ind++; } ct++; } std::cout << "Setting point set with " << ind << " points of " << fixedImage->GetLargestPossibleRegion().GetNumberOfPixels() << " total " << std::endl; metric->SetFixedSampledPointSet(pset); metric->SetUseSampledPointSet(true); std::cout << "Testing metric with point set..." << std::endl; } // initialize metric->Initialize(); // calculate initial metric value typename TMetric::MeasureType initialValue = metric->GetValue(); // scales estimator using RegistrationParameterScalesFromPhysicalShiftType = itk::RegistrationParameterScalesFromPhysicalShift; typename RegistrationParameterScalesFromPhysicalShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromPhysicalShiftType::New(); shiftScaleEstimator->SetMetric(metric); // // optimizer // using OptimizerType = itk::GradientDescentOptimizerv4; auto optimizer = OptimizerType::New(); optimizer->SetMetric(metric); optimizer->SetNumberOfIterations(numberOfIterations); optimizer->SetScalesEstimator(shiftScaleEstimator); if (maximumStepSize > 0) { optimizer->SetMaximumStepSizeInPhysicalUnits(maximumStepSize); } optimizer->StartOptimization(); std::cout << "image size: " << size; std::cout << ", # of iterations: " << optimizer->GetNumberOfIterations() << ", max step size: " << optimizer->GetMaximumStepSizeInPhysicalUnits() << std::endl; std::cout << "imageShift: " << imageShift << std::endl; std::cout << "Transform final parameters: " << translationTransform->GetParameters() << std::endl; // final metric value typename TMetric::MeasureType finalValue = metric->GetValue(); std::cout << "metric value: initial: " << initialValue << ", final: " << finalValue << std::endl; // test that the final position is close to the truth double tolerance = 0.11; for (itk::SizeValueType n = 0; n < Dimension; ++n) { if (itk::Math::abs(1.0 - (static_cast(imageShift[n]) / translationTransform->GetParameters()[n])) > tolerance) { std::cerr << "XXX Failed. Final transform parameters are not within tolerance of image shift. XXX" << std::endl; return EXIT_FAILURE; } } // test that metric value is minimized if (finalValue >= initialValue) { std::cerr << "XXX Failed. Final metric value is not less than initial value. XXX" << std::endl; return EXIT_FAILURE; } return EXIT_SUCCESS; } ////////////////////////////////////////////////////////////// template int itkImageToImageMetricv4RegistrationTestRunAll(int argc, char * argv[]) { using ImageType = itk::Image; // options // we have two options for iterations and step size to accomodate // the different behavior of metrics int numberOfIterations1 = 50; typename ImageType::PixelType maximumStepSize1 = 1.0; int numberOfIterations2 = 120; typename ImageType::PixelType maximumStepSize2 = 0.1; bool doSampling = false; bool doGradientFilter = false; if (argc > 1) { numberOfIterations1 = std::stoi(argv[1]); } if (argc > 2) { maximumStepSize1 = std::stod(argv[2]); } if (argc > 3) { numberOfIterations2 = std::stoi(argv[3]); } if (argc > 4) { maximumStepSize2 = std::stod(argv[4]); } if (argc > 5) { doSampling = std::stoi(argv[5]); } if (argc > 6) { doGradientFilter = std::stoi(argv[6]); } std::cout << std::endl << "******************* Dimension: " << Dimension << std::endl; bool passed = true; // ANTS Neighborhood Correlation // This metric does not support sampling if (!doSampling) { using MetricType = itk::ANTSNeighborhoodCorrelationImageToImageMetricv4; auto metric = MetricType::New(); std::cout << std::endl << "*** ANTSNeighborhoodCorrelation metric: " << std::endl; if (ImageToImageMetricv4RegistrationTestRun( metric, numberOfIterations1, maximumStepSize1, doSampling, doGradientFilter) != EXIT_SUCCESS) { passed = false; } } // Correlation { using MetricType = itk::CorrelationImageToImageMetricv4; auto metric = MetricType::New(); std::cout << std::endl << "*** Correlation metric: " << std::endl; if (ImageToImageMetricv4RegistrationTestRun( metric, numberOfIterations1, maximumStepSize1, doSampling, doGradientFilter) != EXIT_SUCCESS) { passed = false; } } // Joint Histogram { using MetricType = itk::JointHistogramMutualInformationImageToImageMetricv4; auto metric = MetricType::New(); std::cout << std::endl << "*** JointHistogramMutualInformation metric: " << std::endl; if (ImageToImageMetricv4RegistrationTestRun( metric, numberOfIterations1, maximumStepSize1, doSampling, doGradientFilter) != EXIT_SUCCESS) { passed = false; } } // Mattes { using MetricType = itk::MattesMutualInformationImageToImageMetricv4; auto metric = MetricType::New(); std::cout << std::endl << "*** MattesMutualInformation metric: " << std::endl; if (ImageToImageMetricv4RegistrationTestRun( metric, numberOfIterations2, maximumStepSize2, doSampling, doGradientFilter) != EXIT_SUCCESS) { passed = false; } } // MeanSquares { using MetricType = itk::MeanSquaresImageToImageMetricv4; auto metric = MetricType::New(); std::cout << std::endl << "*** MeanSquares metric: " << std::endl; if (ImageToImageMetricv4RegistrationTestRun( metric, numberOfIterations1, maximumStepSize1, doSampling, doGradientFilter) != EXIT_SUCCESS) { passed = false; } } if (passed) { return EXIT_SUCCESS; } else { return EXIT_FAILURE; } } ////////////////////////////////////////////////////////////// int itkImageToImageMetricv4RegistrationTest(int argc, char * argv[]) { int result = EXIT_SUCCESS; if (itkImageToImageMetricv4RegistrationTestRunAll<2>(argc, argv) != EXIT_SUCCESS) { std::cerr << "Failed for one or more metrics. See error message(s) above." << std::endl; result = EXIT_FAILURE; } return result; }