/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkOptImageToImageMetricsTest_h #define itkOptImageToImageMetricsTest_h #include "itkTimeProbe.h" #include "itkMersenneTwisterRandomVariateGenerator.h" namespace itk { template class OptImageToImageMetricsTest { public: OptImageToImageMetricsTest() = default; int RunTest(FixedImageType * fixed, MovingImageType * moving, InterpolatorType * interpolator, TransformType * transform, MetricType * metric, MetricInitializerType metricInitializer) { using ParametersType = typename MetricType::ParametersType; std::cout << "-------------------------------------------------------------------" << std::endl; std::cout << "Testing" << std::endl; std::cout << "\tMetric : " << metric->GetNameOfClass() << std::endl; std::cout << "\tInterpolator : " << interpolator->GetNameOfClass() << std::endl; std::cout << "\tTransform : " << transform->GetNameOfClass() << std::endl; std::cout << "-------------------------------------------------------------------" << std::endl; std::cout << std::endl; int result = EXIT_SUCCESS; // connect the interpolator metric->SetInterpolator(interpolator); // connect the transform metric->SetTransform(transform); // connect the images to the metric metric->SetFixedImage(fixed); metric->SetMovingImage(moving); // call custom initialization for the metric metricInitializer.Initialize(); // Always use the same seed value. // All instances are the same since MersenneTwisterRandomVariateGenerator // uses a singleton pattern. itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()->SetSeed(42); // initialize the metric // Samples are drawn here in metric->Initialize(), // so we seed the random number generator // immediately before this call. metric->Initialize(); // Set the transform to identity transform->SetIdentity(); // Get the transform parameters for identity. ParametersType parameters = transform->GetParameters(); typename MetricType::MeasureType value; typename MetricType::DerivativeType derivative; // Try GetValue and GetDerivative... value = metric->GetValue(parameters); metric->GetDerivative(parameters, derivative); // Make a time probe itk::TimeProbe timeProbe; // Walk around the parameter value at parameterIdx for (unsigned int parameterIdx = 0; parameterIdx < parameters.GetSize(); ++parameterIdx) { std::cout << "Param[" << parameterIdx << "]\tValue\tDerivative " << std::endl; double startVal = parameters[parameterIdx]; // endVal is 10% beyond startVal. double endVal = 1.10 * startVal; // If startVal is 0, endVal needs to be fixed up. if (itk::Math::abs(endVal - 0.0) < 1e-8) { endVal = startVal + 1.0; } double incr = (endVal - startVal) / 10.0; for (double pval = startVal; pval <= endVal; pval += incr) { parameters[parameterIdx] = pval; timeProbe.Start(); metric->GetValueAndDerivative(parameters, value, derivative); timeProbe.Stop(); std::cout << pval << '\t' << value << '\t' << derivative << std::endl; } } std::cout << std::endl; std::cout << "Mean time for GetValueAndDerivative : " << timeProbe.GetMean() << std::endl; std::cout << std::endl; std::cout << "------------------------------Done---------------------------------" << std::endl; return result; } }; template class MeanSquaresMetricInitializer { public: using MetricType = itk::MeanSquaresImageToImageMetric; MeanSquaresMetricInitializer(MetricType * metric) { m_Metric = metric; } void Initialize() { // Do stuff on m_Metric m_Metric->UseAllPixelsOn(); } protected: MetricType * m_Metric; }; template class MattesMIMetricInitializer { public: using MetricType = itk::MattesMutualInformationImageToImageMetric; MattesMIMetricInitializer(MetricType * metric) { m_Metric = metric; } void Initialize() { // Do stuff on m_Metric m_Metric->SetNumberOfHistogramBins(50); m_Metric->SetNumberOfSpatialSamples(5000); } protected: MetricType * m_Metric; }; template class MIMetricInitializer { public: using MetricType = itk::MutualInformationImageToImageMetric; MIMetricInitializer(MetricType * metric) { m_Metric = metric; } void Initialize() { // Do stuff on m_Metric m_Metric->SetNumberOfSpatialSamples(400); } protected: MetricType * m_Metric; }; template void BasicTest(FixedImageReaderType * fixedImageReader, MovingImageReaderType * movingImageReader, InterpolatorType * interpolator, TransformType * transform) { using FixedImageType = typename FixedImageReaderType::OutputImageType; using MovingImageType = typename MovingImageReaderType::OutputImageType; fixedImageReader->Update(); movingImageReader->Update(); typename FixedImageType::Pointer fixed = fixedImageReader->GetOutput(); typename MovingImageType::Pointer moving = movingImageReader->GetOutput(); // Mean squares using MetricType = itk::MeanSquaresImageToImageMetric; auto msMetric = MetricType::New(); MeanSquaresMetricInitializer msMetricInitializer(msMetric); TestAMetric(fixedImageReader, movingImageReader, interpolator, transform, msMetric.GetPointer(), msMetricInitializer); // Mattes MI using MattesMetricType = itk::MattesMutualInformationImageToImageMetric; auto mattesMetric = MattesMetricType::New(); MattesMIMetricInitializer mattesMetricInitializer(mattesMetric); TestAMetric( fixedImageReader, movingImageReader, interpolator, transform, mattesMetric.GetPointer(), mattesMetricInitializer); } template void TestAMetric(FixedImageReaderType * fixedImageReader, MovingImageReaderType * movingImageReader, InterpolatorType * interpolator, TransformType * transform, MetricType * metric, MetricInitializerType metricInitializer) { using FixedImageType = typename FixedImageReaderType::OutputImageType; using MovingImageType = typename MovingImageReaderType::OutputImageType; metric->SetFixedImageRegion(fixedImageReader->GetOutput()->GetBufferedRegion()); OptImageToImageMetricsTest testMetric; testMetric.RunTest( fixedImageReader->GetOutput(), movingImageReader->GetOutput(), interpolator, transform, metric, metricInitializer); } template void AffineLinearTest(FixedImageReaderType * fixedImageReader, MovingImageReaderType * movingImageReader) { using MovingImageType = typename MovingImageReaderType::OutputImageType; using InterpolatorType = itk::LinearInterpolateImageFunction; using TransformType = itk::AffineTransform; auto interpolator = InterpolatorType::New(); auto transform = TransformType::New(); BasicTest(fixedImageReader, movingImageReader, interpolator.GetPointer(), transform.GetPointer()); } template void RigidLinearTest(FixedImageReaderType * fixedImageReader, MovingImageReaderType * movingImageReader) { using MovingImageType = typename MovingImageReaderType::OutputImageType; using InterpolatorType = itk::LinearInterpolateImageFunction; using TransformType = itk::Rigid2DTransform; auto interpolator = InterpolatorType::New(); auto transform = TransformType::New(); BasicTest(fixedImageReader, movingImageReader, interpolator.GetPointer(), transform.GetPointer()); } template void TranslationLinearTest(FixedImageReaderType * fixedImageReader, MovingImageReaderType * movingImageReader) { using MovingImageType = typename MovingImageReaderType::OutputImageType; using InterpolatorType = itk::LinearInterpolateImageFunction; using TransformType = itk::TranslationTransform; auto interpolator = InterpolatorType::New(); auto transform = TransformType::New(); BasicTest(fixedImageReader, movingImageReader, interpolator.GetPointer(), transform.GetPointer()); } template void DoDebugTest(FixedImageReaderType * fixedImageReader, MovingImageReaderType * movingImageReader) { using MovingImageType = typename MovingImageReaderType::OutputImageType; using InterpolatorType = itk::LinearInterpolateImageFunction; using TransformType = itk::Rigid2DTransform; auto interpolator = InterpolatorType::New(); auto transform = TransformType::New(); using FixedImageType = typename FixedImageReaderType::OutputImageType; using MovingImageType = typename MovingImageReaderType::OutputImageType; fixedImageReader->Update(); movingImageReader->Update(); typename FixedImageType::Pointer fixed = fixedImageReader->GetOutput(); typename MovingImageType::Pointer moving = movingImageReader->GetOutput(); // Mean squares using MetricType = itk::MeanSquaresImageToImageMetric; auto metric = MetricType::New(); MeanSquaresMetricInitializer metricInitializer(metric); metric->SetFixedImageRegion(fixedImageReader->GetOutput()->GetBufferedRegion()); using ParametersType = typename MetricType::ParametersType; std::cout << "-------------------------------------------------------------------" << std::endl; std::cout << "Testing" << std::endl; std::cout << "\tMetric : " << metric->GetNameOfClass() << std::endl; std::cout << "\tInterpolator : " << interpolator->GetNameOfClass() << std::endl; std::cout << "\tTransform : " << transform->GetNameOfClass() << std::endl; std::cout << "-------------------------------------------------------------------" << std::endl; std::cout << std::endl; // connect the interpolator metric->SetInterpolator(interpolator); // connect the transform metric->SetTransform(transform); // connect the images to the metric metric->SetFixedImage(fixed); metric->SetMovingImage(moving); // call custom initialization for the metric metricInitializer.Initialize(); // initialize the metric metric->Initialize(); // Set the transform to identity transform->SetIdentity(); // Get the transform parameters for identity. ParametersType parameters = transform->GetParameters(); typename MetricType::MeasureType value; typename MetricType::DerivativeType derivative; parameters[0] = 0.1; metric->GetValueAndDerivative(parameters, value, derivative); // metric->StopDebug(); // Force the test to end here so the debug file // ends at the right place. exit(EXIT_SUCCESS); } } // end namespace itk #endif