/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkImageToImageMetricv4GetValueAndDerivativeThreaderBase_hxx #define itkImageToImageMetricv4GetValueAndDerivativeThreaderBase_hxx #include "itkNumericTraits.h" #include "itkMakeUniqueForOverwrite.h" namespace itk { template ImageToImageMetricv4GetValueAndDerivativeThreaderBase:: ImageToImageMetricv4GetValueAndDerivativeThreaderBase() : m_GetValueAndDerivativePerThreadVariables(nullptr) , m_CachedNumberOfParameters(0) , m_CachedNumberOfLocalParameters(0) {} template void ImageToImageMetricv4GetValueAndDerivativeThreaderBase::BeforeThreadedExecution() { //--------------------------------------------------------------- // Resize the per thread memory objects. //----------------------------------------------------------------- // Cache some values this->m_CachedNumberOfParameters = this->m_Associate->GetNumberOfParameters(); this->m_CachedNumberOfLocalParameters = this->m_Associate->GetNumberOfLocalParameters(); /* Per-thread results */ const ThreadIdType numWorkUnitsUsed = this->GetNumberOfWorkUnitsUsed(); this->m_GetValueAndDerivativePerThreadVariables = make_unique_for_overwrite(numWorkUnitsUsed); if (this->m_Associate->GetComputeDerivative()) { for (ThreadIdType i = 0; i < numWorkUnitsUsed; ++i) { /* Allocate intermediary per-thread storage used to get results from * derived classes */ this->m_GetValueAndDerivativePerThreadVariables[i].LocalDerivatives.SetSize( this->m_CachedNumberOfLocalParameters); this->m_GetValueAndDerivativePerThreadVariables[i].MovingTransformJacobian.SetSize( this->m_Associate->VirtualImageDimension, this->m_CachedNumberOfLocalParameters); // Not pre-allocated since it may not be used // this->m_GetValueAndDerivativePerThreadVariables[i].MovingTransformJacobianPositional if (this->m_Associate->m_MovingTransform->GetTransformCategory() == MovingTransformType::TransformCategoryEnum::DisplacementField) { /* For transforms with local support, e.g. displacement field, * use a single derivative container that's updated by region * in multiple threads. * Initialization to zero is done in main class. */ itkDebugMacro("ImageToImageMetricv4::Initialize: transform HAS local support\n"); /* Set each per-thread object to point to m_DerivativeResult for efficiency. */ this->m_GetValueAndDerivativePerThreadVariables[i].Derivatives.SetData( this->m_Associate->m_DerivativeResult->data_block(), this->m_Associate->m_DerivativeResult->Size(), false); } else { itkDebugMacro("ImageToImageMetricv4::Initialize: transform does NOT have local support\n"); /* This size always comes from the moving image */ const NumberOfParametersType globalDerivativeSize = this->m_CachedNumberOfParameters; /* Global transforms get a separate derivatives container for each thread * that holds the result over a particular image region. * Use a CompensatedSummation value to provide for better consistency between * different number of threads. */ this->m_GetValueAndDerivativePerThreadVariables[i].CompensatedDerivatives.resize(globalDerivativeSize); } } } //--------------------------------------------------------------- // Set initial values. for (ThreadIdType workUnit = 0; workUnit < numWorkUnitsUsed; ++workUnit) { this->m_GetValueAndDerivativePerThreadVariables[workUnit].NumberOfValidPoints = SizeValueType{}; this->m_GetValueAndDerivativePerThreadVariables[workUnit].Measure = InternalComputationValueType{}; if (this->m_Associate->GetComputeDerivative()) { if (this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::TransformCategoryEnum::DisplacementField) { /* Be sure to init to 0 here, because the threader may not use * all the threads if the region is better split into fewer * subregions. */ for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; ++p) { this->m_GetValueAndDerivativePerThreadVariables[workUnit].CompensatedDerivatives[p].ResetToZero(); } } } } } template void ImageToImageMetricv4GetValueAndDerivativeThreaderBase::AfterThreadedExecution() { const ThreadIdType numWorkUnitsUsed = this->GetNumberOfWorkUnitsUsed(); /* Store the number of valid points the enclosing class \c * m_NumberOfValidPoints by collecting the valid points per thread. */ this->m_Associate->m_NumberOfValidPoints = SizeValueType{}; for (ThreadIdType i = 0; i < numWorkUnitsUsed; ++i) { this->m_Associate->m_NumberOfValidPoints += this->m_GetValueAndDerivativePerThreadVariables[i].NumberOfValidPoints; } itkDebugMacro("ImageToImageMetricv4: NumberOfValidPoints: " << this->m_Associate->m_NumberOfValidPoints); /* For global transforms, sum the derivatives from each region. */ if (this->m_Associate->GetComputeDerivative()) { if (this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::TransformCategoryEnum::DisplacementField) { for (NumberOfParametersType p = 0; p < this->m_Associate->GetNumberOfParameters(); ++p) { /* Use a compensated sum to be ready for when there is a very large number of threads */ CompensatedDerivativeValueType sum; sum.ResetToZero(); for (ThreadIdType i = 0; i < numWorkUnitsUsed; ++i) { sum += this->m_GetValueAndDerivativePerThreadVariables[i].CompensatedDerivatives[p].GetSum(); } (*(this->m_Associate->m_DerivativeResult))[p] += sum.GetSum(); } } } /* Check the number of valid points. If there aren't enough, * m_Value and m_DerivativeResult will get appropriate values assigned, * and a warning will be output. */ if (this->m_Associate->VerifyNumberOfValidPoints(this->m_Associate->m_Value, *(this->m_Associate->m_DerivativeResult))) { this->m_Associate->m_Value = MeasureType{}; /* Accumulate the metric value from threads and store the average. */ for (ThreadIdType threadId = 0; threadId < numWorkUnitsUsed; ++threadId) { this->m_Associate->m_Value += this->m_GetValueAndDerivativePerThreadVariables[threadId].Measure; } this->m_Associate->m_Value /= this->m_Associate->m_NumberOfValidPoints; /* For global transforms, calculate the average values */ if (this->m_Associate->GetComputeDerivative()) { if (this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::TransformCategoryEnum::DisplacementField) { *(this->m_Associate->m_DerivativeResult) /= this->m_Associate->m_NumberOfValidPoints; } } } } template bool ImageToImageMetricv4GetValueAndDerivativeThreaderBase::ProcessVirtualPoint( const VirtualIndexType & virtualIndex, const VirtualPointType & virtualPoint, const ThreadIdType threadId) { FixedImagePointType mappedFixedPoint; FixedImagePixelType mappedFixedPixelValue; FixedImageGradientType mappedFixedImageGradient; MovingImagePointType mappedMovingPoint; MovingImagePixelType mappedMovingPixelValue; MovingImageGradientType mappedMovingImageGradient; bool pointIsValid = false; MeasureType metricValueResult; /* Transform the point into fixed and moving spaces, and evaluate. * Do this in a try block to catch exceptions and print more useful info * then we otherwise get when exceptions are caught in MultiThreaderBase. */ try { pointIsValid = this->m_Associate->TransformAndEvaluateFixedPoint(virtualPoint, mappedFixedPoint, mappedFixedPixelValue); if (pointIsValid && this->m_Associate->GetComputeDerivative() && this->m_Associate->GetGradientSourceIncludesFixed()) { this->m_Associate->ComputeFixedImageGradientAtPoint(mappedFixedPoint, mappedFixedImageGradient); } } catch (const ExceptionObject & exc) { // NOTE: there must be a cleaner way to do this: std::string msg("Caught exception: \n"); msg += exc.what(); ExceptionObject err(__FILE__, __LINE__, msg); throw err; } if (!pointIsValid) { return pointIsValid; } try { pointIsValid = this->m_Associate->TransformAndEvaluateMovingPoint(virtualPoint, mappedMovingPoint, mappedMovingPixelValue); if (pointIsValid && this->m_Associate->GetComputeDerivative() && this->m_Associate->GetGradientSourceIncludesMoving()) { this->m_Associate->ComputeMovingImageGradientAtPoint(mappedMovingPoint, mappedMovingImageGradient); } } catch (const ExceptionObject & exc) { std::string msg("Caught exception: \n"); msg += exc.what(); ExceptionObject err(__FILE__, __LINE__, msg); throw err; } if (!pointIsValid) { return pointIsValid; } /* Call the user method in derived classes to do the specific * calculations for value and derivative. */ try { pointIsValid = this->ProcessPoint(virtualIndex, virtualPoint, mappedFixedPoint, mappedFixedPixelValue, mappedFixedImageGradient, mappedMovingPoint, mappedMovingPixelValue, mappedMovingImageGradient, metricValueResult, this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives, threadId); } catch (const ExceptionObject & exc) { // NOTE: there must be a cleaner way to do this: std::string msg("Exception in GetValueAndDerivativeProcessPoint:\n"); msg += exc.what(); ExceptionObject err(__FILE__, __LINE__, msg); throw err; } if (pointIsValid) { this->m_GetValueAndDerivativePerThreadVariables[threadId].NumberOfValidPoints++; this->m_GetValueAndDerivativePerThreadVariables[threadId].Measure += metricValueResult; if (this->m_Associate->GetComputeDerivative()) { this->StorePointDerivativeResult(virtualIndex, threadId); } } return pointIsValid; } template void ImageToImageMetricv4GetValueAndDerivativeThreaderBase:: StorePointDerivativeResult(const VirtualIndexType & virtualIndex, const ThreadIdType threadId) { if (this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::TransformCategoryEnum::DisplacementField) { /* Global support */ if (this->m_Associate->GetUseFloatingPointCorrection()) { DerivativeValueType correctionResolution = this->m_Associate->GetFloatingPointCorrectionResolution(); for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; ++p) { auto test = static_cast( this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p] * correctionResolution); this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p] = static_cast(test / correctionResolution); } } for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; ++p) { this->m_GetValueAndDerivativePerThreadVariables[threadId].CompensatedDerivatives[p] += this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p]; } } else { // Update derivative at some index // this requires the moving image displacement field to be // same size as virtual image, and that VirtualImage PixelType // is scalar (which is verified during Metric initialization). try { OffsetValueType offset = this->m_Associate->ComputeParameterOffsetFromVirtualIndex(virtualIndex, this->m_CachedNumberOfLocalParameters); for (NumberOfParametersType i = 0; i < this->m_CachedNumberOfLocalParameters; ++i) { /* Be sure to *add* here and not assign. Required for proper behavior * with multi-variate metric. */ this->m_GetValueAndDerivativePerThreadVariables[threadId].Derivatives[offset + i] += this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[i]; } } catch (const ExceptionObject & exc) { std::string msg("Caught exception: \n"); msg += exc.what(); ExceptionObject err(__FILE__, __LINE__, msg); throw err; } } } template bool ImageToImageMetricv4GetValueAndDerivativeThreaderBase::GetComputeDerivative() const { return this->m_Associate->GetComputeDerivative(); } } // end namespace itk #endif