/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkImageRegistrationMethodv4_hxx #define itkImageRegistrationMethodv4_hxx #include "itkSmoothingRecursiveGaussianImageFilter.h" #include "itkGradientDescentOptimizerv4.h" #include "itkImageRandomConstIteratorWithIndex.h" #include "itkImageRegionConstIteratorWithIndex.h" #include "itkImageToImageMetricv4.h" #include "itkIterationReporter.h" #include "itkMattesMutualInformationImageToImageMetricv4.h" #include "itkMersenneTwisterRandomVariateGenerator.h" #include "itkRegistrationParameterScalesFromPhysicalShift.h" #include "itkPrintHelper.h" namespace itk { template ImageRegistrationMethodv4::ImageRegistrationMethodv4() { ProcessObject::SetNumberOfRequiredOutputs(1); Self::SetPrimaryOutputName("Transform"); // indexed input are alternating fixed and moving images Self::SetPrimaryInputName("Fixed"); Self::AddRequiredInputName("Moving", 1); ProcessObject::SetNumberOfRequiredInputs(2); // optional named inputs Self::SetInput("InitialTransform", nullptr); Self::SetInput("FixedInitialTransform", nullptr); Self::SetInput("MovingInitialTransform", nullptr); this->m_VirtualDomainImage = nullptr; Self::ReleaseDataBeforeUpdateFlagOff(); this->m_CurrentLevel = 0; this->m_CurrentIteration = 0; this->m_CurrentMetricValue = 0.0; this->m_CurrentConvergenceValue = 0.0; this->m_IsConverged = false; this->m_NumberOfFixedObjects = 0; this->m_NumberOfMovingObjects = 0; Self::ReleaseDataBeforeUpdateFlagOff(); this->m_InPlace = true; this->m_InitializeCenterOfLinearOutputTransform = true; this->m_CompositeTransform = CompositeTransformType::New(); using DefaultMetricType = MattesMutualInformationImageToImageMetricv4; auto mutualInformationMetric = DefaultMetricType::New(); mutualInformationMetric->SetNumberOfHistogramBins(20); mutualInformationMetric->SetUseMovingImageGradientFilter(false); mutualInformationMetric->SetUseFixedImageGradientFilter(false); mutualInformationMetric->SetUseSampledPointSet(false); this->m_Metric = mutualInformationMetric; using DefaultScalesEstimatorType = RegistrationParameterScalesFromPhysicalShift; auto scalesEstimator = DefaultScalesEstimatorType::New(); scalesEstimator->SetMetric(mutualInformationMetric); scalesEstimator->SetTransformForward(true); using DefaultOptimizerType = GradientDescentOptimizerv4Template; auto optimizer = DefaultOptimizerType::New(); optimizer->SetLearningRate(1.0); optimizer->SetNumberOfIterations(1000); optimizer->SetScalesEstimator(scalesEstimator); this->m_Optimizer = optimizer; this->m_OptimizerWeights.SetSize(0); this->m_OptimizerWeightsAreIdentity = true; DecoratedOutputTransformPointer transformDecorator = itkDynamicCastInDebugMode(this->MakeOutput(0).GetPointer()); this->ProcessObject::SetNthOutput(0, transformDecorator); this->m_OutputTransform = transformDecorator->GetModifiable(); // Default to a 3-level image registration: force the initialization of the ivars that depend on the number of levels. this->SetNumberOfLevels(3); this->m_ShrinkFactorsPerLevel.resize(this->m_NumberOfLevels); ShrinkFactorsPerDimensionContainerType shrinkFactors; shrinkFactors.Fill(2); this->m_ShrinkFactorsPerLevel[0] = shrinkFactors; shrinkFactors.Fill(1); this->m_ShrinkFactorsPerLevel[1] = shrinkFactors; shrinkFactors.Fill(1); this->m_ShrinkFactorsPerLevel[2] = shrinkFactors; this->m_SmoothingSigmasPerLevel.SetSize(this->m_NumberOfLevels); this->m_SmoothingSigmasPerLevel[0] = 2; this->m_SmoothingSigmasPerLevel[1] = 1; this->m_SmoothingSigmasPerLevel[2] = 0; this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits = true; this->m_ReseedIterator = false; this->m_RandomSeed = Statistics::MersenneTwisterRandomVariateGenerator::GetNextSeed(); this->m_CurrentRandomSeed = this->m_RandomSeed; this->m_MetricSamplingStrategy = MetricSamplingStrategyEnum::NONE; this->m_MetricSamplingPercentagePerLevel.SetSize(this->m_NumberOfLevels); this->m_MetricSamplingPercentagePerLevel.Fill(1.0); } template void ImageRegistrationMethodv4::SetFixedImage( SizeValueType index, const FixedImageType * image) { itkDebugMacro("setting fixed image input " << index << " to " << image); if (image != static_cast(this->ProcessObject::GetInput(2 * index))) { if (!this->ProcessObject::GetInput(2 * index)) { this->m_NumberOfFixedObjects++; } this->ProcessObject::SetNthInput(2 * index, const_cast(image)); this->Modified(); } } template const typename ImageRegistrationMethodv4:: FixedImageType * ImageRegistrationMethodv4::GetFixedImage( SizeValueType index) const { itkDebugMacro("returning fixed image input " << index << " of " << static_cast(this->ProcessObject::GetInput(2 * index))); return static_cast(this->ProcessObject::GetInput(2 * index)); } template void ImageRegistrationMethodv4::SetMovingImage( SizeValueType index, const MovingImageType * image) { itkDebugMacro("setting moving image input " << index << " to " << image); if (image != static_cast(this->ProcessObject::GetInput(2 * index + 1))) { if (!this->ProcessObject::GetInput(2 * index + 1)) { this->m_NumberOfMovingObjects++; } this->ProcessObject::SetNthInput(2 * index + 1, const_cast(image)); this->Modified(); } } template const typename ImageRegistrationMethodv4:: MovingImageType * ImageRegistrationMethodv4::GetMovingImage( SizeValueType index) const { itkDebugMacro("returning moving image input " << index << " of " << static_cast(this->ProcessObject::GetInput(2 * index + 1))); return static_cast(this->ProcessObject::GetInput(2 * index + 1)); } template void ImageRegistrationMethodv4::SetFixedPointSet( SizeValueType index, const PointSetType * pointSet) { itkDebugMacro("setting fixed point set input " << index << " to " << pointSet); if (pointSet != static_cast(this->ProcessObject::GetInput(2 * index))) { if (!this->ProcessObject::GetInput(2 * index)) { this->m_NumberOfFixedObjects++; } this->ProcessObject::SetNthInput(2 * index, const_cast(pointSet)); this->Modified(); } } template const typename ImageRegistrationMethodv4:: PointSetType * ImageRegistrationMethodv4::GetFixedPointSet( SizeValueType index) const { itkDebugMacro("returning fixed point set input " << index << " of " << static_cast(this->ProcessObject::GetInput(2 * index))); return static_cast(this->ProcessObject::GetInput(2 * index)); } template void ImageRegistrationMethodv4::SetMovingPointSet( SizeValueType index, const PointSetType * pointSet) { itkDebugMacro("setting moving point set input " << index << " to " << pointSet); if (pointSet != static_cast(this->ProcessObject::GetInput(2 * index + 1))) { if (!this->ProcessObject::GetInput(2 * index + 1)) { this->m_NumberOfMovingObjects++; } this->ProcessObject::SetNthInput(2 * index + 1, const_cast(pointSet)); this->Modified(); } } template const typename ImageRegistrationMethodv4:: PointSetType * ImageRegistrationMethodv4::GetMovingPointSet( SizeValueType index) const { itkDebugMacro("returning moving point set input " << index << " of " << static_cast(this->ProcessObject::GetInput(2 * index + 1))); return static_cast(this->ProcessObject::GetInput(2 * index + 1)); } template void ImageRegistrationMethodv4::SetOptimizerWeights( OptimizerWeightsType & weights) { if (weights != this->m_OptimizerWeights) { itkDebugMacro("setting optimizer weights to " << weights); this->m_OptimizerWeights = weights; // Check to see if optimizer weights are identity to avoid unnecessary // computations. this->m_OptimizerWeightsAreIdentity = true; if (this->m_OptimizerWeights.Size() > 0) { using OptimizerWeightsValueType = typename OptimizerWeightsType::ValueType; auto tolerance = static_cast(1e-4); for (SizeValueType i = 0; i < this->m_OptimizerWeights.Size(); ++i) { OptimizerWeightsValueType difference = itk::Math::abs(NumericTraits::OneValue() - this->m_OptimizerWeights[i]); if (difference > tolerance) { this->m_OptimizerWeightsAreIdentity = false; break; } } } this->Modified(); } } template void ImageRegistrationMethodv4:: InitializeRegistrationAtEachLevel(const SizeValueType level) { // To avoid casting to a multimetric several times, we do it once and use it // throughout this function if the current enumerated metric type is MULTI_METRIC typename MultiMetricType::Pointer multiMetric = dynamic_cast(this->m_Metric.GetPointer()); // Sanity checks and find the virtual domain image if (level == 0) { SizeValueType numberOfObjectPairs = static_cast(0.5 * this->GetNumberOfIndexedInputs()); if (numberOfObjectPairs == 0) { itkExceptionMacro("There are no input objects."); } if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC) { this->m_NumberOfMetrics = multiMetric->GetNumberOfMetrics(); if (this->m_NumberOfMetrics != numberOfObjectPairs) { itkExceptionMacro("Mismatch between number of image pairs and the number of metrics."); } } else { this->m_NumberOfMetrics = 1; } // The number of image pairs also includes nullptr image pairs for the point set // metrics if (this->m_NumberOfFixedObjects != this->m_NumberOfMovingObjects) { itkExceptionMacro("The number of fixed and moving images is not equal."); } } if (!this->m_Optimizer) { itkExceptionMacro("The optimizer is not present."); } if (!this->m_Metric) { itkExceptionMacro("The metric is not present."); } auto * movingInitialTransform = const_cast(this->GetMovingInitialTransform()); auto * fixedInitialTransform = const_cast(this->GetFixedInitialTransform()); this->m_CurrentIteration = 0; this->m_CurrentMetricValue = 0.0; this->m_CurrentConvergenceValue = 0.0; this->m_IsConverged = false; this->InvokeEvent(MultiResolutionIterationEvent()); // For each level, we adapt the current transform. For many transforms, e.g. // affine, the base transform adaptor does not do anything. However, in the // case of other transforms, e.g. the b-spline and displacement field transforms // the fixed parameters are changed to reflect an increase in transform resolution. // This could involve increasing the mesh size of the B-spline transform or // increase the resolution of the displacement field. if (this->m_TransformParametersAdaptorsPerLevel[level]) { this->m_TransformParametersAdaptorsPerLevel[level]->SetTransform(this->m_OutputTransform); this->m_TransformParametersAdaptorsPerLevel[level]->AdaptTransformParameters(); } // Set-up the composite transform at initialization // Also, find the virtual domain image if (level == 0) { this->m_CompositeTransform->ClearTransformQueue(); // Since we cannot instantiate a null object from an abstract class, we need to initialize the moving // initial transform as an identity transform. // Nevertheless, we do not need add this transform to the composite transform when it is only an // identity transform. Simply by not setting that, we can save lots of time in jacobian computations // of the composite transform since we can avoid some matrix multiplications. // Skip adding an IdentityTransform to the m_CompositeTransform if (movingInitialTransform != nullptr && std::string(movingInitialTransform->GetNameOfClass()) != std::string("IdentityTransform")) { this->m_CompositeTransform->AddTransform(movingInitialTransform); } // If the moving initial transform is a composite transform, unroll // it into m_CompositeTransform. this->m_CompositeTransform->FlattenTransformQueue(); if (this->m_InitializeCenterOfLinearOutputTransform) { this->InitializeCenterOfLinearOutputTransform(); } this->m_CompositeTransform->AddTransform(this->m_OutputTransform); if (this->m_OptimizerWeights.Size() > 0) { this->m_Optimizer->SetWeights(this->m_OptimizerWeights); } // Get index of first image metric this->m_FirstImageMetricIndex = -1; if (this->m_NumberOfMetrics == 1 && this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { this->m_FirstImageMetricIndex = 0; } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC) { for (SizeValueType n = 0; n < this->m_NumberOfMetrics; ++n) { if (multiMetric->GetMetricQueue()[n]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { this->m_FirstImageMetricIndex = n; break; } } } { VirtualImageBaseConstPointer virtualDomainBaseImage = this->GetCurrentLevelVirtualDomainImage(); if (virtualDomainBaseImage.IsNull() && this->m_FirstImageMetricIndex >= 0) { virtualDomainBaseImage = this->GetFixedImage(this->m_FirstImageMetricIndex); } this->m_VirtualDomainImage = VirtualImageType::New(); this->m_VirtualDomainImage->CopyInformation(virtualDomainBaseImage); this->m_VirtualDomainImage->SetRegions(virtualDomainBaseImage->GetLargestPossibleRegion()); this->m_VirtualDomainImage->Allocate(); } this->m_FixedImageMasks.clear(); this->m_FixedImageMasks.resize(this->m_NumberOfMetrics); this->m_MovingImageMasks.clear(); this->m_MovingImageMasks.resize(this->m_NumberOfMetrics); for (SizeValueType n = 0; n < this->m_NumberOfMetrics; ++n) { this->m_FixedImageMasks[n] = nullptr; this->m_MovingImageMasks[n] = nullptr; if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC || (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC && multiMetric->GetMetricQueue()[n]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC)) { if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC) { this->m_FixedImageMasks[n] = dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer())->GetFixedImageMask(); this->m_MovingImageMasks[n] = dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer())->GetMovingImageMask(); } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { this->m_FixedImageMasks[n] = dynamic_cast(this->m_Metric.GetPointer())->GetFixedImageMask(); this->m_MovingImageMasks[n] = dynamic_cast(this->m_Metric.GetPointer())->GetMovingImageMask(); } else { itkExceptionMacro("Invalid metric type."); } } } } this->m_CompositeTransform->SetOnlyMostRecentTransformToOptimizeOn(); // At each resolution and for each image pair (assuming an image metric), we // 1. subsample the reference domain (typically the fixed image) and/or // 2. smooth the fixed and moving images. typename VirtualImageType::Pointer currentLevelVirtualDomainImage = nullptr; if (this->m_VirtualDomainImage.IsNotNull()) { auto shrinkFilter = ShrinkFilterType::New(); shrinkFilter->SetShrinkFactors(this->m_ShrinkFactorsPerLevel[level]); shrinkFilter->SetInput(this->m_VirtualDomainImage); currentLevelVirtualDomainImage = shrinkFilter->GetOutput(); currentLevelVirtualDomainImage->Update(); } else { itkExceptionMacro("A virtual domain image is not found. It should be specified in one of the metrics."); } if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC) { if (fixedInitialTransform) { multiMetric->SetFixedTransform(fixedInitialTransform); } else { using IdentityTransformType = IdentityTransform; auto defaultFixedInitialTransform = IdentityTransformType::New(); multiMetric->SetFixedTransform(defaultFixedInitialTransform); } multiMetric->SetMovingTransform(this->m_CompositeTransform); if (currentLevelVirtualDomainImage.IsNotNull()) { multiMetric->SetVirtualDomainFromImage(currentLevelVirtualDomainImage); } for (SizeValueType n = 0; n < multiMetric->GetNumberOfMetrics(); ++n) { if (multiMetric->GetMetricQueue()[n]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { if (currentLevelVirtualDomainImage.IsNotNull()) { dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer()) ->SetVirtualDomainFromImage(currentLevelVirtualDomainImage); } else { itkExceptionMacro("Virtual domain image is not specified."); } } else if (multiMetric->GetMetricQueue()[n]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC) { if (currentLevelVirtualDomainImage.IsNotNull()) { // This casting is a hack as all the metrics should be coordinated such that // they have identical virtual image types. using CasterType = CastImageFilter; auto caster = CasterType::New(); caster->SetInput(currentLevelVirtualDomainImage); caster->Update(); dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer()) ->SetVirtualDomainFromImage(caster->GetOutput()); } } } } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { typename ImageMetricType::Pointer imageMetric = dynamic_cast(this->m_Metric.GetPointer()); if (fixedInitialTransform) { imageMetric->SetFixedTransform(fixedInitialTransform); } else { using IdentityTransformType = IdentityTransform; auto defaultFixedInitialTransform = IdentityTransformType::New(); imageMetric->SetFixedTransform(defaultFixedInitialTransform); } imageMetric->SetMovingTransform(this->m_CompositeTransform); if (currentLevelVirtualDomainImage.IsNotNull()) { imageMetric->SetVirtualDomainFromImage(currentLevelVirtualDomainImage); } } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC) { typename PointSetMetricType::Pointer pointSetMetric = dynamic_cast(this->m_Metric.GetPointer()); if (fixedInitialTransform) { pointSetMetric->SetFixedTransform(fixedInitialTransform); } else { using IdentityTransformType = IdentityTransform; auto defaultFixedInitialTransform = IdentityTransformType::New(); pointSetMetric->SetFixedTransform(defaultFixedInitialTransform); } pointSetMetric->SetMovingTransform(this->m_CompositeTransform); if (currentLevelVirtualDomainImage.IsNotNull()) { // This casting is a hack as all the metrics should be coordinated such that // they have identical virtual image types. using CasterType = CastImageFilter; auto caster = CasterType::New(); caster->SetInput(currentLevelVirtualDomainImage); caster->Update(); pointSetMetric->SetVirtualDomainFromImage(caster->GetOutput()); } } else { itkExceptionMacro("Invalid metric conversion."); } // We update the fixed and moving images for the image metrics and // the fixed and moving point sets for the point set metrics. Note // that we set the point sets here just like we set the images. // Although this isn't necessary, we want to leave the option for // changing the point sets per level. this->m_FixedSmoothImages.clear(); this->m_FixedSmoothImages.resize(this->m_NumberOfMetrics); this->m_MovingSmoothImages.clear(); this->m_MovingSmoothImages.resize(this->m_NumberOfMetrics); this->m_FixedPointSets.clear(); this->m_FixedPointSets.resize(this->m_NumberOfMetrics); this->m_MovingPointSets.clear(); this->m_MovingPointSets.resize(this->m_NumberOfMetrics); for (SizeValueType n = 0; n < this->m_NumberOfMetrics; ++n) { this->m_FixedSmoothImages[n] = nullptr; this->m_MovingSmoothImages[n] = nullptr; this->m_FixedPointSets[n] = nullptr; this->m_MovingPointSets[n] = nullptr; if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC || (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC && multiMetric->GetMetricQueue()[n]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC)) { if (this->m_SmoothingSigmasPerLevel[level] > 0) { using FixedImageSmoothingFilterType = SmoothingRecursiveGaussianImageFilter; typename FixedImageSmoothingFilterType::Pointer fixedImageSmoothingFilter = FixedImageSmoothingFilterType::New(); typename FixedImageSmoothingFilterType::SigmaArrayType fixedImageSigmaArray( this->m_SmoothingSigmasPerLevel[level]); if (!this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits) { auto & fixedSpacing = this->GetFixedImage(n)->GetSpacing(); for (unsigned int i = 0; i < fixedImageSigmaArray.Size(); ++i) { fixedImageSigmaArray[i] *= fixedSpacing[i]; } } fixedImageSmoothingFilter->SetSigmaArray(fixedImageSigmaArray); fixedImageSmoothingFilter->SetInput(this->GetFixedImage(n)); this->m_FixedSmoothImages[n] = fixedImageSmoothingFilter->GetOutput(); fixedImageSmoothingFilter->Update(); fixedImageSmoothingFilter->GetOutput()->DisconnectPipeline(); using MovingImageSmoothingFilterType = SmoothingRecursiveGaussianImageFilter; typename MovingImageSmoothingFilterType::Pointer movingImageSmoothingFilter = MovingImageSmoothingFilterType::New(); typename MovingImageSmoothingFilterType::SigmaArrayType movingImageSigmaArray( this->m_SmoothingSigmasPerLevel[level]); if (!this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits) { auto & movingSpacing = this->GetMovingImage(n)->GetSpacing(); for (unsigned int i = 0; i < movingImageSigmaArray.Size(); ++i) { movingImageSigmaArray[i] *= movingSpacing[i]; } } movingImageSmoothingFilter->SetSigmaArray(movingImageSigmaArray); movingImageSmoothingFilter->SetInput(this->GetMovingImage(n)); this->m_MovingSmoothImages[n] = movingImageSmoothingFilter->GetOutput(); movingImageSmoothingFilter->Update(); movingImageSmoothingFilter->GetOutput()->DisconnectPipeline(); } else { this->m_MovingSmoothImages[n] = this->GetMovingImage(n); this->m_FixedSmoothImages[n] = this->GetFixedImage(n); } // Update the image metric if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC) { multiMetric->GetMetricQueue()[n]->SetFixedObject(this->m_FixedSmoothImages[n]); multiMetric->GetMetricQueue()[n]->SetMovingObject(this->m_MovingSmoothImages[n]); dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer()) ->SetFixedImageMask(this->m_FixedImageMasks[n]); dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer()) ->SetMovingImageMask(this->m_MovingImageMasks[n]); } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { this->m_Metric->SetFixedObject(this->m_FixedSmoothImages[n]); this->m_Metric->SetMovingObject(this->m_MovingSmoothImages[n]); dynamic_cast(this->m_Metric.GetPointer())->SetFixedImageMask(this->m_FixedImageMasks[n]); dynamic_cast(this->m_Metric.GetPointer())->SetMovingImageMask(this->m_MovingImageMasks[n]); } else { itkExceptionMacro("Invalid metric type."); } } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC || (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC && multiMetric->GetMetricQueue()[n]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC)) { this->m_FixedPointSets[n] = this->GetFixedPointSet(n); this->m_MovingPointSets[n] = this->GetMovingPointSet(n); // Update the point set metric if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::MULTI_METRIC) { multiMetric->GetMetricQueue()[n]->SetFixedObject(this->GetFixedPointSet(n)); multiMetric->GetMetricQueue()[n]->SetMovingObject(this->GetMovingPointSet(n)); } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC) { this->m_Metric->SetFixedObject(this->GetFixedPointSet(n)); this->m_Metric->SetMovingObject(this->GetMovingPointSet(n)); } else { itkExceptionMacro("Invalid metric type."); } } else { itkExceptionMacro("Invalid metric type."); } } if (this->m_MetricSamplingStrategy != MetricSamplingStrategyEnum::NONE) { this->SetMetricSamplePoints(); } // Update the optimizer this->m_Optimizer->SetMetric(this->m_Metric); if (this->m_Optimizer->CanUseScales() && ((this->m_Optimizer->GetScales()).Size() != this->m_OutputTransform->GetNumberOfLocalParameters())) { using ScalesType = typename OptimizerType::ScalesType; ScalesType scales; scales.SetSize(this->m_OutputTransform->GetNumberOfLocalParameters()); scales.Fill(NumericTraits::OneValue()); this->m_Optimizer->SetScales(scales); } } template void ImageRegistrationMethodv4::AllocateOutputs() { const DecoratedInitialTransformType * decoratedInitialTransform = this->GetInitialTransformInput(); DecoratedOutputTransformType * decoratedOutputTransform = this->GetOutput(); if (decoratedInitialTransform) { const InitialTransformType * initialTransform = decoratedInitialTransform->Get(); if (initialTransform) { if (this->GetInPlace()) { // graft the input to the output which may fail if the types // aren't compatible. decoratedOutputTransform->Graft(decoratedInitialTransform); if (decoratedOutputTransform->Get()) { this->m_OutputTransform = decoratedOutputTransform->GetModifiable(); // This is generally done in the ReleaseInputs methods, // however we do not need it again const_cast(decoratedInitialTransform)->ReleaseData(); // successful in-place grafting itkDebugMacro("inplace allocation of output transform"); return; } } const auto * initialAsOutputTransform = dynamic_cast(initialTransform); if (initialAsOutputTransform) { // Clone performs a deep copy of the parameters and composition this->m_OutputTransform = initialAsOutputTransform->Clone(); decoratedOutputTransform->Set(this->m_OutputTransform); // successful deep copy from initial to output itkDebugMacro("clone copy allocation of output transform"); return; } else { itkExceptionMacro("Unable to convert InitialTransform input to the OutputTransform type"); } } } // fallback allocation and initialization // initialize to identity? what happens if we re-run with optimized values? itkDebugMacro("fallback allocation of output transform"); if (!decoratedOutputTransform->Get()) { // the output decorated component is null, allocate OutputTransformPointer ptr; Self::MakeOutputTransform(ptr); decoratedOutputTransform->Set(ptr); } this->m_OutputTransform = this->GetModifiableTransform(); } template void ImageRegistrationMethodv4::GenerateData() { this->AllocateOutputs(); // Ensure the same seed is used for each update this->m_CurrentRandomSeed = this->m_RandomSeed; for (this->m_CurrentLevel = 0; this->m_CurrentLevel < this->m_NumberOfLevels; this->m_CurrentLevel++) { this->InitializeRegistrationAtEachLevel(this->m_CurrentLevel); this->m_Metric->Initialize(); this->m_Optimizer->StartOptimization(); } } template void ImageRegistrationMethodv4:: SetTransformParametersAdaptorsPerLevel(TransformParametersAdaptorsContainerType & adaptors) { if (this->m_NumberOfLevels != adaptors.size()) { itkExceptionMacro("The number of levels does not equal the number array size."); } else { itkDebugMacro("Setting the transform parameter adaptors."); this->m_TransformParametersAdaptorsPerLevel = adaptors; this->Modified(); } } template const typename ImageRegistrationMethodv4:: TransformParametersAdaptorsContainerType & ImageRegistrationMethodv4:: GetTransformParametersAdaptorsPerLevel() const { return this->m_TransformParametersAdaptorsPerLevel; } template void ImageRegistrationMethodv4::SetNumberOfLevels( const SizeValueType numberOfLevels) { if (this->m_NumberOfLevels != numberOfLevels) { this->m_NumberOfLevels = numberOfLevels; // Set default transform adaptors which don't do anything to the input transform // Similarly, fill in some default values for the shrink factors, smoothing sigmas, // and learning rates. this->m_TransformParametersAdaptorsPerLevel.clear(); for (SizeValueType level = 0; level < this->m_NumberOfLevels; ++level) { this->m_TransformParametersAdaptorsPerLevel.push_back(nullptr); } for (SizeValueType level = 0; level < this->m_NumberOfLevels; ++level) { ShrinkFactorsPerDimensionContainerType shrinkFactors; shrinkFactors.Fill(1); this->SetShrinkFactorsPerDimension(level, shrinkFactors); } this->m_SmoothingSigmasPerLevel.SetSize(this->m_NumberOfLevels); this->m_SmoothingSigmasPerLevel.Fill(1.0); this->m_MetricSamplingPercentagePerLevel.SetSize(this->m_NumberOfLevels); this->m_MetricSamplingPercentagePerLevel.Fill(1.0); this->Modified(); } } template void ImageRegistrationMethodv4::SetMetricSamplePoints() { using VirtualDomainImageType = typename ImageMetricType::VirtualImageType; using VirtualDomainRegionType = typename VirtualDomainImageType::RegionType; const VirtualDomainImageType * virtualImage = nullptr; const FixedImageMaskType * fixedMaskImage = nullptr; SizeValueType numberOfLocalMetrics = 1; typename MultiMetricType::Pointer multiMetric = dynamic_cast(this->m_Metric.GetPointer()); if (multiMetric) { numberOfLocalMetrics = multiMetric->GetNumberOfMetrics(); if (numberOfLocalMetrics < 1) { itkExceptionMacro("Input multi metric should have at least one metric component."); } else { typename ImageMetricType::Pointer firstMetric = dynamic_cast(multiMetric->GetMetricQueue()[0].GetPointer()); if (firstMetric.IsNotNull()) { virtualImage = firstMetric->GetVirtualImage(); fixedMaskImage = firstMetric->GetFixedImageMask(); } else { itkExceptionMacro("Invalid metric conversion."); } } } else { typename ImageMetricType::Pointer singleMetric = dynamic_cast(this->m_Metric.GetPointer()); if (singleMetric.IsNotNull()) { virtualImage = singleMetric->GetVirtualImage(); fixedMaskImage = singleMetric->GetFixedImageMask(); } else { itkExceptionMacro("Invalid metric conversion."); } } const VirtualDomainRegionType & virtualDomainRegion = virtualImage->GetRequestedRegion(); const typename VirtualDomainImageType::SpacingType oneThirdVirtualSpacing = virtualImage->GetSpacing() / 3.0; for (SizeValueType n = 0; n < numberOfLocalMetrics; ++n) { auto samplePointSet = MetricSamplePointSetType::New(); samplePointSet->Initialize(); using SamplePointType = typename MetricSamplePointSetType::PointType; using RandomizerType = Statistics::MersenneTwisterRandomVariateGenerator; auto randomizer = RandomizerType::New(); if (m_ReseedIterator) { randomizer->SetSeed(); } else { randomizer->SetSeed(m_CurrentRandomSeed++); } unsigned long index = 0; switch (this->m_MetricSamplingStrategy) { case MetricSamplingStrategyEnum::REGULAR: { const auto sampleCount = static_cast(std::ceil(1.0 / this->m_MetricSamplingPercentagePerLevel[this->m_CurrentLevel])); unsigned long count = sampleCount; // Start at sampleCount to keep behavior backwards identical, using first element. ImageRegionConstIteratorWithIndex It(virtualImage, virtualDomainRegion); for (It.GoToBegin(); !It.IsAtEnd(); ++It) { if (count == sampleCount) { count = 0; // Reset counter SamplePointType point; virtualImage->TransformIndexToPhysicalPoint(It.GetIndex(), point); // randomly perturb the point within a voxel (approximately) for (SizeValueType d = 0; d < ImageDimension; ++d) { point[d] += randomizer->GetNormalVariate() * oneThirdVirtualSpacing[d]; } if (!fixedMaskImage || fixedMaskImage->IsInsideInWorldSpace(point)) { samplePointSet->SetPoint(index, point); ++index; } } ++count; } break; } case MetricSamplingStrategyEnum::RANDOM: { const unsigned long totalVirtualDomainVoxels = virtualDomainRegion.GetNumberOfPixels(); const auto sampleCount = static_cast(static_cast(totalVirtualDomainVoxels) * this->m_MetricSamplingPercentagePerLevel[this->m_CurrentLevel]); ImageRandomConstIteratorWithIndex ItR(virtualImage, virtualDomainRegion); if (m_ReseedIterator) { ItR.ReinitializeSeed(); } else { ItR.ReinitializeSeed(m_CurrentRandomSeed++); } ItR.SetNumberOfSamples(sampleCount); for (ItR.GoToBegin(); !ItR.IsAtEnd(); ++ItR) { SamplePointType point; virtualImage->TransformIndexToPhysicalPoint(ItR.GetIndex(), point); // randomly perturb the point within a voxel (approximately) for (unsigned int d = 0; d < ImageDimension; ++d) { point[d] += randomizer->GetNormalVariate() * oneThirdVirtualSpacing[d]; } if (!fixedMaskImage || fixedMaskImage->IsInsideInWorldSpace(point)) { samplePointSet->SetPoint(index, point); ++index; } } break; } default: { itkExceptionMacro("Invalid sampling strategy requested."); } } if (multiMetric) { dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer()) ->SetVirtualSampledPointSet(samplePointSet); dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer())->UseSampledPointSetOn(); dynamic_cast(multiMetric->GetMetricQueue()[n].GetPointer())->UseVirtualSampledPointSetOn(); } else { dynamic_cast(this->m_Metric.GetPointer())->SetVirtualSampledPointSet(samplePointSet); dynamic_cast(this->m_Metric.GetPointer())->UseSampledPointSetOn(); dynamic_cast(this->m_Metric.GetPointer())->UseVirtualSampledPointSetOn(); } } } template void ImageRegistrationMethodv4:: InitializeCenterOfLinearOutputTransform() { using MatrixOffsetTransformType = MatrixOffsetTransformBase; auto * matrixOffsetOutputTransform = dynamic_cast(this->GetModifiableTransform()); if (!matrixOffsetOutputTransform) { return; } SizeValueType numberOfTransforms = this->m_CompositeTransform->GetNumberOfTransforms(); if (numberOfTransforms == 0) { return; } typename TTransform::InputPointType center = matrixOffsetOutputTransform->GetCenter(); bool optimalIndexFound = false; for (int i = numberOfTransforms - 1; i >= 0; i--) { auto * matrixOffsetTransform = dynamic_cast(this->m_CompositeTransform->GetNthTransformModifiablePointer(i)); if (matrixOffsetTransform) { center = matrixOffsetTransform->GetCenter(); optimalIndexFound = true; break; } } if (!optimalIndexFound) { return; } else { matrixOffsetOutputTransform->SetCenter(center); } } template typename ImageRegistrationMethodv4:: VirtualImageBaseConstPointer ImageRegistrationMethodv4:: GetCurrentLevelVirtualDomainImage() { // Get virtual domain image VirtualImageBaseConstPointer currentLevelVirtualDomainImage = nullptr; if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::IMAGE_METRIC) { currentLevelVirtualDomainImage = dynamic_cast(this->m_Metric.GetPointer())->GetVirtualImage(); } else if (this->m_Metric->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC) { currentLevelVirtualDomainImage = dynamic_cast(this->m_Metric.GetPointer())->GetVirtualImage(); } else { typename MultiMetricType::Pointer multiMetric = dynamic_cast(this->m_Metric.GetPointer()); if (multiMetric->GetMetricQueue()[0]->GetMetricCategory() == ObjectToObjectMetricBaseTemplateEnums::MetricCategory::POINT_SET_METRIC) { currentLevelVirtualDomainImage = dynamic_cast(multiMetric->GetMetricQueue()[0].GetPointer())->GetVirtualImage(); } else { currentLevelVirtualDomainImage = dynamic_cast(multiMetric->GetMetricQueue()[0].GetPointer())->GetVirtualImage(); } } return currentLevelVirtualDomainImage; } template void ImageRegistrationMethodv4::PrintSelf( std::ostream & os, Indent indent) const { using namespace print_helper; Superclass::PrintSelf(os, indent); os << indent << "CurrentLevel: " << static_cast::PrintType>(m_CurrentLevel) << std::endl; os << indent << "NumberOfLevels: " << static_cast::PrintType>(m_NumberOfLevels) << std::endl; os << indent << "CurrentIteration: " << static_cast::PrintType>(m_CurrentIteration) << std::endl; os << indent << "CurrentMetricValue: " << static_cast::PrintType>(m_CurrentMetricValue) << std::endl; os << indent << "CurrentConvergenceValue: " << static_cast::PrintType>(m_CurrentConvergenceValue) << std::endl; os << indent << "IsConverged: " << (m_IsConverged ? "On" : "Off") << std::endl; os << indent << "FixedSmoothImages: " << m_FixedSmoothImages << std::endl; os << indent << "MovingSmoothImages: " << m_MovingSmoothImages << std::endl; os << indent << "FixedImageMasks: " << m_FixedImageMasks << std::endl; os << indent << "MovingImageMasks: " << m_MovingImageMasks << std::endl; itkPrintSelfObjectMacro(VirtualDomainImage); os << indent << "FixedPointSets: " << m_FixedPointSets << std::endl; os << indent << "MovingPointSets: " << m_MovingPointSets << std::endl; os << indent << "NumberOfFixedObjects: " << static_cast::PrintType>(m_NumberOfFixedObjects) << std::endl; os << indent << "NumberOfMovingObjects: " << static_cast::PrintType>(m_NumberOfMovingObjects) << std::endl; itkPrintSelfObjectMacro(Optimizer); os << indent << "OptimizerWeights: " << static_cast::PrintType>(m_OptimizerWeights) << std::endl; os << indent << "OptimizerWeightsAreIdentity: " << (m_OptimizerWeightsAreIdentity ? "On" : "Off") << std::endl; itkPrintSelfObjectMacro(Metric); os << indent << "MetricSamplingStrategy: " << m_MetricSamplingStrategy << std::endl; os << indent << "MetricSamplingPercentagePerLevel: " << m_MetricSamplingPercentagePerLevel << std::endl; os << indent << "NumberOfMetrics: " << static_cast::PrintType>(m_NumberOfMetrics) << std::endl; os << indent << "FirstImageMetricIndex: " << m_FirstImageMetricIndex << std::endl; os << indent << "ShrinkFactorsPerLevel: " << m_ShrinkFactorsPerLevel << std::endl; os << indent << "SmoothingSigmasPerLevel: " << m_SmoothingSigmasPerLevel << std::endl; os << indent << "SmoothingSigmasAreSpecifiedInPhysicalUnits: " << (m_SmoothingSigmasAreSpecifiedInPhysicalUnits ? "On" : "Off") << std::endl; os << indent << "ReseedIterator: " << (m_ReseedIterator ? "On" : "Off") << std::endl; os << indent << "RandomSeed: " << m_RandomSeed << std::endl; os << indent << "CurrentRandomSeed: " << m_CurrentRandomSeed << std::endl; os << indent << "TransformParametersAdaptorsPerLevel: "; for (const auto & val : m_TransformParametersAdaptorsPerLevel) { os << indent.GetNextIndent() << val << " "; } os << std::endl; itkPrintSelfObjectMacro(CompositeTransform); itkPrintSelfObjectMacro(OutputTransform); os << indent << "InPlace: " << (m_InPlace ? "On" : "Off") << std::endl; os << indent << "InitializeCenterOfLinearOutputTransform: " << (m_InitializeCenterOfLinearOutputTransform ? "On" : "Off") << std::endl; } template typename ImageRegistrationMethodv4:: DecoratedOutputTransformType * ImageRegistrationMethodv4::GetOutput() { return static_cast(this->ProcessObject::GetOutput(0)); } template const typename ImageRegistrationMethodv4:: DecoratedOutputTransformType * ImageRegistrationMethodv4::GetOutput() const { return static_cast(this->ProcessObject::GetOutput(0)); } template typename ImageRegistrationMethodv4:: OutputTransformType * ImageRegistrationMethodv4::GetModifiableTransform() { DecoratedOutputTransformType * temp = this->GetOutput(); // required outputs of process object should always exits itkAssertInDebugAndIgnoreInReleaseMacro(temp != nullptr); return temp->GetModifiable(); } template const typename ImageRegistrationMethodv4:: OutputTransformType * ImageRegistrationMethodv4::GetTransform() const { const DecoratedOutputTransformType * temp = this->GetOutput(); // required outputs of process object should always exits itkAssertInDebugAndIgnoreInReleaseMacro(temp != nullptr); return temp->Get(); } template DataObject::Pointer ImageRegistrationMethodv4::MakeOutput( DataObjectPointerArraySizeType output) { if (output > 0) { itkExceptionMacro("MakeOutput request for an output number larger than the expected number of outputs."); } OutputTransformPointer ptr; Self::MakeOutputTransform(ptr); DecoratedOutputTransformPointer transformDecorator = DecoratedOutputTransformType::New(); transformDecorator->Set(ptr); return transformDecorator.GetPointer(); } template void ImageRegistrationMethodv4:: MetricSamplingReinitializeSeed() { if (!m_ReseedIterator) { m_ReseedIterator = true; this->Modified(); } } template void ImageRegistrationMethodv4:: MetricSamplingReinitializeSeed(int seed) { if (m_ReseedIterator || m_RandomSeed != seed) { m_ReseedIterator = false; m_RandomSeed = seed; m_CurrentRandomSeed = seed; this->Modified(); } } template void ImageRegistrationMethodv4::SetMetricSamplingPercentage( const RealType samplingPercentage) { MetricSamplingPercentageArrayType samplingPercentagePerLevel; samplingPercentagePerLevel.SetSize(this->m_NumberOfLevels); samplingPercentagePerLevel.Fill(samplingPercentage); this->SetMetricSamplingPercentagePerLevel(samplingPercentagePerLevel); } template void ImageRegistrationMethodv4:: SetMetricSamplingPercentagePerLevel(const MetricSamplingPercentageArrayType & samplingPercentages) { if (this->m_MetricSamplingPercentagePerLevel != samplingPercentages) { for (typename MetricSamplingPercentageArrayType::const_iterator it = samplingPercentages.begin(); it != samplingPercentages.end(); it++) { if (*it <= 0.0 || *it > 1.0) { itkExceptionMacro("sampling percentage outside expected (0,1] range"); } } this->m_MetricSamplingPercentagePerLevel = samplingPercentages; this->Modified(); } } } // end namespace itk #endif