/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkESMDemonsRegistrationFunction_hxx #define itkESMDemonsRegistrationFunction_hxx #include "itkMath.h" namespace itk { template ESMDemonsRegistrationFunction::ESMDemonsRegistrationFunction() { RadiusType r; unsigned int j; for (j = 0; j < ImageDimension; ++j) { r[j] = 0; } this->SetRadius(r); m_TimeStep = 1.0; m_DenominatorThreshold = 1e-9; m_IntensityDifferenceThreshold = 0.001; m_MaximumUpdateStepLength = 0.5; this->SetMovingImage(nullptr); this->SetFixedImage(nullptr); m_FixedImageSpacing.Fill(1.0); m_FixedImageOrigin.Fill(0.0); m_FixedImageDirection.SetIdentity(); m_Normalizer = 0.0; m_FixedImageGradientCalculator = GradientCalculatorType::New(); // Gradient orientation will be taken care of explicitly m_FixedImageGradientCalculator->UseImageDirectionOff(); m_MappedMovingImageGradientCalculator = MovingImageGradientCalculatorType::New(); // Gradient orientation will be taken care of explicitly m_MappedMovingImageGradientCalculator->UseImageDirectionOff(); this->m_UseGradientType = GradientEnum::Symmetric; auto interp = DefaultInterpolatorType::New(); m_MovingImageInterpolator = itkDynamicCastInDebugMode(interp.GetPointer()); m_MovingImageWarper = WarperType::New(); m_MovingImageWarper->SetInterpolator(m_MovingImageInterpolator); m_MovingImageWarper->SetEdgePaddingValue(NumericTraits::max()); m_MovingImageWarperOutput = nullptr; m_Metric = NumericTraits::max(); m_SumOfSquaredDifference = 0.0; m_NumberOfPixelsProcessed = 0L; m_RMSChange = NumericTraits::max(); m_SumOfSquaredChange = 0.0; } template void ESMDemonsRegistrationFunction::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); os << indent << "UseGradientType: " << m_UseGradientType << std::endl; os << indent << "MaximumUpdateStepLength: " << m_MaximumUpdateStepLength << std::endl; itkPrintSelfObjectMacro(MovingImageInterpolator); itkPrintSelfObjectMacro(FixedImageGradientCalculator); itkPrintSelfObjectMacro(MappedMovingImageGradientCalculator); os << indent << "DenominatorThreshold: " << m_DenominatorThreshold << std::endl; os << indent << "IntensityDifferenceThreshold: " << m_IntensityDifferenceThreshold << std::endl; os << indent << "Metric: " << m_Metric << std::endl; os << indent << "SumOfSquaredDifference: " << m_SumOfSquaredDifference << std::endl; os << indent << "NumberOfPixelsProcessed: " << m_NumberOfPixelsProcessed << std::endl; os << indent << "RMSChange: " << m_RMSChange << std::endl; os << indent << "SumOfSquaredChange: " << m_SumOfSquaredChange << std::endl; } template void ESMDemonsRegistrationFunction::SetIntensityDifferenceThreshold( double threshold) { m_IntensityDifferenceThreshold = threshold; } template double ESMDemonsRegistrationFunction::GetIntensityDifferenceThreshold() const { return m_IntensityDifferenceThreshold; } template void ESMDemonsRegistrationFunction::InitializeIteration() { if (!this->GetMovingImage() || !this->GetFixedImage() || !m_MovingImageInterpolator) { itkExceptionMacro("MovingImage, FixedImage and/or Interpolator not set"); } // cache fixed image information m_FixedImageOrigin = this->GetFixedImage()->GetOrigin(); m_FixedImageSpacing = this->GetFixedImage()->GetSpacing(); m_FixedImageDirection = this->GetFixedImage()->GetDirection(); // compute the normalizer if (m_MaximumUpdateStepLength > 0.0) { m_Normalizer = 0.0; for (unsigned int k = 0; k < ImageDimension; ++k) { m_Normalizer += m_FixedImageSpacing[k] * m_FixedImageSpacing[k]; } m_Normalizer *= m_MaximumUpdateStepLength * m_MaximumUpdateStepLength / static_cast(ImageDimension); } else { // set it to minus one to denote a special case // ( unrestricted update length ) m_Normalizer = -1.0; } // setup gradient calculator m_FixedImageGradientCalculator->SetInputImage(this->GetFixedImage()); m_MappedMovingImageGradientCalculator->SetInputImage(this->GetMovingImage()); // Compute warped moving image m_MovingImageWarper->SetOutputOrigin(this->m_FixedImageOrigin); m_MovingImageWarper->SetOutputSpacing(this->m_FixedImageSpacing); m_MovingImageWarper->SetOutputDirection(this->m_FixedImageDirection); m_MovingImageWarper->SetInput(this->GetMovingImage()); m_MovingImageWarper->SetDisplacementField(this->GetDisplacementField()); m_MovingImageWarper->GetOutput()->SetRequestedRegion(this->GetDisplacementField()->GetRequestedRegion()); m_MovingImageWarper->Update(); this->m_MovingImageWarperOutput = this->m_MovingImageWarper->GetOutput(); // setup moving image interpolator for further access m_MovingImageInterpolator->SetInputImage(this->GetMovingImage()); // initialize metric computation variables m_SumOfSquaredDifference = 0.0; m_NumberOfPixelsProcessed = 0L; m_SumOfSquaredChange = 0.0; } template auto ESMDemonsRegistrationFunction::ComputeUpdate( const NeighborhoodType & it, void * gd, const FloatOffsetType & itkNotUsed(offset)) -> PixelType { auto * globalData = (GlobalDataStruct *)gd; PixelType update; IndexType FirstIndex = this->GetFixedImage()->GetLargestPossibleRegion().GetIndex(); IndexType LastIndex = this->GetFixedImage()->GetLargestPossibleRegion().GetIndex() + this->GetFixedImage()->GetLargestPossibleRegion().GetSize(); const IndexType index = it.GetIndex(); // Get fixed image related information // Note: no need to check if the index is within // fixed image buffer. This is done by the external filter. const auto fixedValue = static_cast(this->GetFixedImage()->GetPixel(index)); // Get moving image related information // check if the point was mapped outside of the moving image using // the "special value" NumericTraits::max() MovingPixelType movingPixValue = m_MovingImageWarperOutput->GetPixel(index); if (movingPixValue == NumericTraits::max()) { update.Fill(0.0); return update; } const auto movingValue = static_cast(movingPixValue); // We compute the gradient more or less by hand. // We first start by ignoring the image orientation and introduce it // afterwards CovariantVectorType usedOrientFreeGradientTimes2; if ((this->m_UseGradientType == GradientEnum::Symmetric) || (this->m_UseGradientType == GradientEnum::WarpedMoving)) { // we don't use a CentralDifferenceImageFunction here to be able to // check for NumericTraits::max() CovariantVectorType warpedMovingGradient; IndexType tmpIndex = index; for (unsigned int dim = 0; dim < ImageDimension; ++dim) { // bounds checking if (FirstIndex[dim] == LastIndex[dim] || index[dim] < FirstIndex[dim] || index[dim] >= LastIndex[dim]) { warpedMovingGradient[dim] = 0.0; continue; } else if (index[dim] == FirstIndex[dim]) { // compute derivative tmpIndex[dim] += 1; movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex); if (movingPixValue == NumericTraits::max()) { // weird crunched border case warpedMovingGradient[dim] = 0.0; } else { // forward difference warpedMovingGradient[dim] = static_cast(movingPixValue) - movingValue; warpedMovingGradient[dim] /= m_FixedImageSpacing[dim]; } tmpIndex[dim] -= 1; continue; } else if (index[dim] == (LastIndex[dim] - 1)) { // compute derivative tmpIndex[dim] -= 1; movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex); if (movingPixValue == NumericTraits::max()) { // weird crunched border case warpedMovingGradient[dim] = 0.0; } else { // backward difference warpedMovingGradient[dim] = movingValue - static_cast(movingPixValue); warpedMovingGradient[dim] /= m_FixedImageSpacing[dim]; } tmpIndex[dim] += 1; continue; } // compute derivative tmpIndex[dim] += 1; movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex); if (movingPixValue == NumericTraits::max()) { // backward difference warpedMovingGradient[dim] = movingValue; tmpIndex[dim] -= 2; movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex); if (movingPixValue == NumericTraits::max()) { // weird crunched border case warpedMovingGradient[dim] = 0.0; } else { // backward difference warpedMovingGradient[dim] -= static_cast(m_MovingImageWarperOutput->GetPixel(tmpIndex)); warpedMovingGradient[dim] /= m_FixedImageSpacing[dim]; } } else { warpedMovingGradient[dim] = static_cast(movingPixValue); tmpIndex[dim] -= 2; movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex); if (movingPixValue == NumericTraits::max()) { // forward difference warpedMovingGradient[dim] -= movingValue; warpedMovingGradient[dim] /= m_FixedImageSpacing[dim]; } else { // normal case, central difference warpedMovingGradient[dim] -= static_cast(movingPixValue); warpedMovingGradient[dim] *= 0.5 / m_FixedImageSpacing[dim]; } } tmpIndex[dim] += 1; } if (this->m_UseGradientType == GradientEnum::Symmetric) { // Compute orientation-free gradient with calculator const CovariantVectorType fixedGradient = m_FixedImageGradientCalculator->EvaluateAtIndex(index); usedOrientFreeGradientTimes2 = fixedGradient + warpedMovingGradient; } else if (this->m_UseGradientType == GradientEnum::WarpedMoving) { usedOrientFreeGradientTimes2 = warpedMovingGradient + warpedMovingGradient; } else { itkExceptionMacro("Unknown gradient type"); } } else if (this->m_UseGradientType == GradientEnum::Fixed) { // Compute orientation-free gradient with calculator const CovariantVectorType fixedGradient = m_FixedImageGradientCalculator->EvaluateAtIndex(index); usedOrientFreeGradientTimes2 = fixedGradient + fixedGradient; } else if (this->m_UseGradientType == GradientEnum::MappedMoving) { PointType mappedPoint; this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint); for (unsigned int j = 0; j < ImageDimension; ++j) { mappedPoint[j] += it.GetCenterPixel()[j]; } const CovariantVectorType mappedMovingGradient = m_MappedMovingImageGradientCalculator->Evaluate(mappedPoint); usedOrientFreeGradientTimes2 = mappedMovingGradient + mappedMovingGradient; } else { itkExceptionMacro("Unknown gradient type"); } const auto usedGradientTimes2 = this->GetFixedImage()->TransformLocalVectorToPhysicalVector(usedOrientFreeGradientTimes2); /** * Compute Update. * We avoid the mismatch in units between the two terms. * and avoid large step using a normalization term. */ const double usedGradientTimes2SquaredMagnitude = usedGradientTimes2.GetSquaredNorm(); const double speedValue = fixedValue - movingValue; if (itk::Math::abs(speedValue) < m_IntensityDifferenceThreshold) { update.Fill(0.0); } else { double denom; if (m_Normalizer > 0.0) { // "ITK-Thirion" normalization denom = usedGradientTimes2SquaredMagnitude + (itk::Math::sqr(speedValue) / m_Normalizer); } else { // least square solution of the system denom = usedGradientTimes2SquaredMagnitude; } if (denom < m_DenominatorThreshold) { update.Fill(0.0); } else { const double factor = 2.0 * speedValue / denom; for (unsigned int j = 0; j < ImageDimension; ++j) { update[j] = factor * usedGradientTimes2[j]; } } } // WARNING!! We compute the global data without taking into account the // current update step. // There are several reasons for that: If an exponential, a smoothing or any // other operation // is applied on the update field, we cannot compute the newMappedCenterPoint // here; and even // if we could, this would be an often unnecessary time-consuming task. if (globalData) { globalData->m_SumOfSquaredDifference += itk::Math::sqr(speedValue); globalData->m_NumberOfPixelsProcessed += 1; globalData->m_SumOfSquaredChange += update.GetSquaredNorm(); } return update; } template void ESMDemonsRegistrationFunction::ReleaseGlobalDataPointer(void * gd) const { const std::unique_ptr globalData(static_cast(gd)); const std::lock_guard lockGuard(m_MetricCalculationMutex); m_SumOfSquaredDifference += globalData->m_SumOfSquaredDifference; m_NumberOfPixelsProcessed += globalData->m_NumberOfPixelsProcessed; m_SumOfSquaredChange += globalData->m_SumOfSquaredChange; if (m_NumberOfPixelsProcessed) { m_Metric = m_SumOfSquaredDifference / static_cast(m_NumberOfPixelsProcessed); m_RMSChange = std::sqrt(m_SumOfSquaredChange / static_cast(m_NumberOfPixelsProcessed)); } } } // end namespace itk #endif