/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkMultiResolutionPDEDeformableRegistration_hxx #define itkMultiResolutionPDEDeformableRegistration_hxx #include "itkRecursiveGaussianImageFilter.h" #include "itkRecursiveMultiResolutionPyramidImageFilter.h" #include "itkImageRegionIterator.h" #include "itkMath.h" namespace itk { template MultiResolutionPDEDeformableRegistration::MultiResolutionPDEDeformableRegistration() { this->SetNumberOfRequiredInputs(2); // Primary input is optional in this filter this->RemoveRequiredInputName("Primary"); auto registrator = DefaultRegistrationType::New(); m_RegistrationFilter = registrator.GetPointer(); m_MovingImagePyramid = MovingImagePyramidType::New(); m_FixedImagePyramid = FixedImagePyramidType::New(); m_FieldExpander = FieldExpanderType::New(); m_InitialDisplacementField = nullptr; m_NumberOfLevels = 3; m_NumberOfIterations.SetSize(m_NumberOfLevels); m_FixedImagePyramid->SetNumberOfLevels(m_NumberOfLevels); m_MovingImagePyramid->SetNumberOfLevels(m_NumberOfLevels); unsigned int ilevel; for (ilevel = 0; ilevel < m_NumberOfLevels; ++ilevel) { m_NumberOfIterations[ilevel] = 10; } m_CurrentLevel = 0; m_StopRegistrationFlag = false; } template void MultiResolutionPDEDeformableRegistration::SetMovingImage(const MovingImageType * ptr) { this->ProcessObject::SetNthInput(2, const_cast(ptr)); } template const typename MultiResolutionPDEDeformableRegistration::MovingImageType * MultiResolutionPDEDeformableRegistration::GetMovingImage() const { return dynamic_cast(this->ProcessObject::GetInput(2)); } template void MultiResolutionPDEDeformableRegistration::SetFixedImage(const FixedImageType * ptr) { this->ProcessObject::SetNthInput(1, const_cast(ptr)); } template const typename MultiResolutionPDEDeformableRegistration::FixedImageType * MultiResolutionPDEDeformableRegistration::GetFixedImage() const { return dynamic_cast(this->ProcessObject::GetInput(1)); } template std::vector>::size_type MultiResolutionPDEDeformableRegistration::GetNumberOfValidRequiredInputs() const { typename std::vector>::size_type num = 0; if (this->GetFixedImage()) { ++num; } if (this->GetMovingImage()) { ++num; } return num; } template void MultiResolutionPDEDeformableRegistration::SetNumberOfLevels(unsigned int num) { if (m_NumberOfLevels != num) { this->Modified(); m_NumberOfLevels = num; m_NumberOfIterations.SetSize(m_NumberOfLevels); } if (m_MovingImagePyramid && m_MovingImagePyramid->GetNumberOfLevels() != num) { m_MovingImagePyramid->SetNumberOfLevels(m_NumberOfLevels); } if (m_FixedImagePyramid && m_FixedImagePyramid->GetNumberOfLevels() != num) { m_FixedImagePyramid->SetNumberOfLevels(m_NumberOfLevels); } } template void MultiResolutionPDEDeformableRegistration::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); itkPrintSelfObjectMacro(RegistrationFilter); itkPrintSelfObjectMacro(MovingImagePyramid); itkPrintSelfObjectMacro(FixedImagePyramid); itkPrintSelfObjectMacro(FieldExpander); itkPrintSelfObjectMacro(InitialDisplacementField); os << indent << "NumberOfLevels: " << m_NumberOfLevels << std::endl; os << indent << "CurrentLevel: " << m_CurrentLevel << std::endl; os << indent << "NumberOfIterations: ["; unsigned int ilevel; for (ilevel = 0; ilevel < m_NumberOfLevels - 1; ++ilevel) { os << m_NumberOfIterations[ilevel] << ", "; } os << m_NumberOfIterations[ilevel] << ']' << std::endl; os << indent << "StopRegistrationFlag: " << (m_StopRegistrationFlag ? "On" : "Off") << std::endl; } template void MultiResolutionPDEDeformableRegistration::GenerateData() { // Check for nullptr images and pointers MovingImageConstPointer movingImage = this->GetMovingImage(); FixedImageConstPointer fixedImage = this->GetFixedImage(); if (!movingImage || !fixedImage) { itkExceptionMacro("Fixed and/or moving image not set"); } if (!m_MovingImagePyramid || !m_FixedImagePyramid) { itkExceptionMacro("Fixed and/or moving pyramid not set"); } if (!m_RegistrationFilter) { itkExceptionMacro("Registration filter not set"); } if (this->m_InitialDisplacementField && this->GetInput(0)) { itkExceptionMacro("Only one initial deformation can be given. " << "SetInitialDisplacementField should not be used in " << "cunjunction with SetArbitraryInitialDisplacementField " << "or SetInput."); } // as per suggestion in this bug report: // https://public.kitware.com/Bug/view.php?id=3590 // this should allow input images to be released, since // they are no longer needed after generating the image pyramid this->RestoreInputReleaseDataFlags(); // Create the image pyramids. m_MovingImagePyramid->SetInput(movingImage); m_MovingImagePyramid->UpdateLargestPossibleRegion(); m_FixedImagePyramid->SetInput(fixedImage); m_FixedImagePyramid->UpdateLargestPossibleRegion(); // Initializations m_CurrentLevel = 0; m_StopRegistrationFlag = false; unsigned int movingLevel = std::min(static_cast(m_CurrentLevel), static_cast(m_MovingImagePyramid->GetNumberOfLevels())); unsigned int fixedLevel = std::min(static_cast(m_CurrentLevel), static_cast(m_FixedImagePyramid->GetNumberOfLevels())); DisplacementFieldPointer tempField = nullptr; DisplacementFieldPointer inputPtr = const_cast(this->GetInput(0)); if (this->m_InitialDisplacementField) { tempField = this->m_InitialDisplacementField; } else if (inputPtr) { // Arbitrary initial deformation field is set. // smooth it and resample // First smooth it tempField = inputPtr; using GaussianFilterType = RecursiveGaussianImageFilter; auto smoother = GaussianFilterType::New(); for (unsigned int dim = 0; dim < DisplacementFieldType::ImageDimension; ++dim) { // sigma accounts for the subsampling of the pyramid double sigma = 0.5 * static_cast(m_FixedImagePyramid->GetSchedule()[fixedLevel][dim]); // but also for a possible discrepancy in the spacing sigma *= fixedImage->GetSpacing()[dim] / inputPtr->GetSpacing()[dim]; smoother->SetInput(tempField); smoother->SetSigma(sigma); smoother->SetDirection(dim); smoother->Update(); tempField = smoother->GetOutput(); tempField->DisconnectPipeline(); } // Now resample m_FieldExpander->SetInput(tempField); typename FloatImageType::Pointer fi = m_FixedImagePyramid->GetOutput(fixedLevel); m_FieldExpander->SetSize(fi->GetLargestPossibleRegion().GetSize()); m_FieldExpander->SetOutputStartIndex(fi->GetLargestPossibleRegion().GetIndex()); m_FieldExpander->SetOutputOrigin(fi->GetOrigin()); m_FieldExpander->SetOutputSpacing(fi->GetSpacing()); m_FieldExpander->SetOutputDirection(fi->GetDirection()); m_FieldExpander->UpdateLargestPossibleRegion(); m_FieldExpander->SetInput(nullptr); tempField = m_FieldExpander->GetOutput(); tempField->DisconnectPipeline(); } bool lastShrinkFactorsAllOnes = false; while (!this->Halt()) { if (tempField.IsNull()) { m_RegistrationFilter->SetInitialDisplacementField(nullptr); } else { // Resample the field to be the same size as the fixed image // at the current level m_FieldExpander->SetInput(tempField); typename FloatImageType::Pointer fi = m_FixedImagePyramid->GetOutput(fixedLevel); m_FieldExpander->SetSize(fi->GetLargestPossibleRegion().GetSize()); m_FieldExpander->SetOutputStartIndex(fi->GetLargestPossibleRegion().GetIndex()); m_FieldExpander->SetOutputOrigin(fi->GetOrigin()); m_FieldExpander->SetOutputSpacing(fi->GetSpacing()); m_FieldExpander->SetOutputDirection(fi->GetDirection()); m_FieldExpander->UpdateLargestPossibleRegion(); m_FieldExpander->SetInput(nullptr); tempField = m_FieldExpander->GetOutput(); tempField->DisconnectPipeline(); m_RegistrationFilter->SetInitialDisplacementField(tempField); } // setup registration filter and pyramids m_RegistrationFilter->SetMovingImage(m_MovingImagePyramid->GetOutput(movingLevel)); m_RegistrationFilter->SetFixedImage(m_FixedImagePyramid->GetOutput(fixedLevel)); m_RegistrationFilter->SetNumberOfIterations(m_NumberOfIterations[m_CurrentLevel]); // cache shrink factors for computing the next expand factors. lastShrinkFactorsAllOnes = true; for (unsigned int idim = 0; idim < ImageDimension; ++idim) { if (m_FixedImagePyramid->GetSchedule()[fixedLevel][idim] > 1) { lastShrinkFactorsAllOnes = false; break; } } // compute new deformation field m_RegistrationFilter->UpdateLargestPossibleRegion(); tempField = m_RegistrationFilter->GetOutput(); tempField->DisconnectPipeline(); // Increment level counter. ++m_CurrentLevel; movingLevel = std::min(static_cast(m_CurrentLevel), static_cast(m_MovingImagePyramid->GetNumberOfLevels())); fixedLevel = std::min(static_cast(m_CurrentLevel), static_cast(m_FixedImagePyramid->GetNumberOfLevels())); // Invoke an iteration event. this->InvokeEvent(MultiResolutionIterationEvent()); // We can release data from pyramid which are no longer required. if (movingLevel > 0) { m_MovingImagePyramid->GetOutput(movingLevel - 1)->ReleaseData(); } if (fixedLevel > 0) { m_FixedImagePyramid->GetOutput(fixedLevel - 1)->ReleaseData(); } } // while not Halt() if (!lastShrinkFactorsAllOnes) { // Some of the last shrink factors are not one // graft the output of the expander filter to // to output of this filter // resample the field to the same size as the fixed image m_FieldExpander->SetInput(tempField); m_FieldExpander->SetSize(fixedImage->GetLargestPossibleRegion().GetSize()); m_FieldExpander->SetOutputStartIndex(fixedImage->GetLargestPossibleRegion().GetIndex()); m_FieldExpander->SetOutputOrigin(fixedImage->GetOrigin()); m_FieldExpander->SetOutputSpacing(fixedImage->GetSpacing()); m_FieldExpander->SetOutputDirection(fixedImage->GetDirection()); m_FieldExpander->UpdateLargestPossibleRegion(); this->GraftOutput(m_FieldExpander->GetOutput()); } else { // all the last shrink factors are all ones // graft the output of registration filter to // to output of this filter this->GraftOutput(tempField); } // Release memory m_FieldExpander->SetInput(nullptr); m_FieldExpander->GetOutput()->ReleaseData(); m_RegistrationFilter->SetInput(nullptr); m_RegistrationFilter->GetOutput()->ReleaseData(); } template void MultiResolutionPDEDeformableRegistration::StopRegistration() { m_RegistrationFilter->StopRegistration(); m_StopRegistrationFlag = true; } template bool MultiResolutionPDEDeformableRegistration::Halt() { // Halt the registration after the user-specified number of levels if (m_NumberOfLevels != 0) { this->UpdateProgress(static_cast(m_CurrentLevel) / static_cast(m_NumberOfLevels)); } if (m_CurrentLevel >= m_NumberOfLevels) { return true; } if (m_StopRegistrationFlag) { return true; } else { return false; } } template void MultiResolutionPDEDeformableRegistration::GenerateOutputInformation() { typename DataObject::Pointer output; if (this->GetInput(0)) { // Initial deformation field is set. // Copy information from initial field. this->Superclass::GenerateOutputInformation(); } else if (this->GetFixedImage()) { // Initial deformation field is not set. // Copy information from the fixed image. for (unsigned int idx = 0; idx < this->GetNumberOfIndexedOutputs(); ++idx) { output = this->GetOutput(idx); if (output) { output->CopyInformation(this->GetFixedImage()); } } } } template void MultiResolutionPDEDeformableRegistration::GenerateInputRequestedRegion() { // call the superclass's implementation Superclass::GenerateInputRequestedRegion(); // request the largest possible region for the moving image MovingImagePointer movingPtr = const_cast(this->GetMovingImage()); if (movingPtr) { movingPtr->SetRequestedRegionToLargestPossibleRegion(); } // just propagate up the output requested region for // the fixed image and initial deformation field. DisplacementFieldPointer inputPtr = const_cast(this->GetInput()); DisplacementFieldPointer outputPtr = this->GetOutput(); FixedImagePointer fixedPtr = const_cast(this->GetFixedImage()); if (inputPtr) { inputPtr->SetRequestedRegion(outputPtr->GetRequestedRegion()); } if (fixedPtr) { fixedPtr->SetRequestedRegion(outputPtr->GetRequestedRegion()); } } template void MultiResolutionPDEDeformableRegistration::EnlargeOutputRequestedRegion(DataObject * ptr) { // call the superclass's implementation Superclass::EnlargeOutputRequestedRegion(ptr); // set the output requested region to largest possible. DisplacementFieldType * outputPtr; outputPtr = dynamic_cast(ptr); if (outputPtr) { outputPtr->SetRequestedRegionToLargestPossibleRegion(); } } } // end namespace itk #endif