/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkLabeledPointSetToPointSetMetricv4_hxx #define itkLabeledPointSetToPointSetMetricv4_hxx #include "itkEuclideanDistancePointSetToPointSetMetricv4.h" #include "itkPrintHelper.h" #include namespace itk { template LabeledPointSetToPointSetMetricv4:: LabeledPointSetToPointSetMetricv4() { using DefaultMetricType = EuclideanDistancePointSetToPointSetMetricv4; auto euclideanMetric = DefaultMetricType::New(); this->m_PointSetMetric = euclideanMetric; this->m_UsePointSetData = true; } template void LabeledPointSetToPointSetMetricv4::Initialize() { if (!this->m_FixedPointSet->GetPointData() || this->m_FixedPointSet->GetPoints()->Size() != this->m_FixedPointSet->GetPointData()->Size() || !this->m_MovingPointSet->GetPointData() || this->m_MovingPointSet->GetPoints()->Size() != this->m_MovingPointSet->GetPointData()->Size()) { itkExceptionMacro("Each point of the point set must be associated with a label."); } this->DetermineCommonPointSetLabels(); // Create point set metric instantiations for each label typename LabelSetType::const_iterator it; for (it = this->m_CommonPointSetLabels.begin(); it != this->m_CommonPointSetLabels.end(); ++it) { typename PointSetMetricType::Pointer metric = dynamic_cast(this->m_PointSetMetric->Clone().GetPointer()); if (metric.IsNull()) { itkExceptionMacro("The metric pointer clone is nullptr."); } FixedPointSetPointer fixedPointSet = this->GetLabeledFixedPointSet(*it); MovingPointSetPointer movingPointSet = this->GetLabeledMovingPointSet(*it); metric->SetFixedPointSet(fixedPointSet); metric->SetMovingPointSet(movingPointSet); metric->SetFixedTransform(this->GetModifiableFixedTransform()); metric->SetMovingTransform(this->GetModifiableMovingTransform()); metric->SetCalculateValueAndDerivativeInTangentSpace(this->GetCalculateValueAndDerivativeInTangentSpace()); metric->SetStoreDerivativeAsSparseFieldForLocalSupportTransforms( this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()); metric->Initialize(); this->m_PointSetMetricClones.push_back(metric); } } template typename LabeledPointSetToPointSetMetricv4::MeasureType LabeledPointSetToPointSetMetricv4:: GetLocalNeighborhoodValue(const PointType & point, const LabelType & label) const { auto labelIt = std::find(this->m_CommonPointSetLabels.begin(), this->m_CommonPointSetLabels.end(), label); if (labelIt == this->m_CommonPointSetLabels.end()) { itkExceptionMacro("Label not found in common label set"); } else { unsigned int labelIndex = labelIt - this->m_CommonPointSetLabels.begin(); MeasureType value = this->m_PointSetMetricClones[labelIndex]->GetLocalNeighborhoodValue(point, label); return value; } } template void LabeledPointSetToPointSetMetricv4:: GetLocalNeighborhoodValueAndDerivative(const PointType & point, MeasureType & measure, LocalDerivativeType & localDerivative, const LabelType & label) const { auto labelIt = std::find(this->m_CommonPointSetLabels.begin(), this->m_CommonPointSetLabels.end(), label); if (labelIt == this->m_CommonPointSetLabels.end()) { itkExceptionMacro("Label not found in common label set"); } else { unsigned int labelIndex = labelIt - this->m_CommonPointSetLabels.begin(); this->m_PointSetMetricClones[labelIndex]->GetLocalNeighborhoodValueAndDerivative( point, measure, localDerivative, label); } } template typename LabeledPointSetToPointSetMetricv4:: FixedPointSetPointer LabeledPointSetToPointSetMetricv4:: GetLabeledFixedPointSet(const LabelType label) const { auto fixedPointSet = FixedPointSetType::New(); fixedPointSet->Initialize(); typename FixedPointSetType::PointIdentifier count{}; typename FixedPointSetType::PointsContainerConstIterator It = this->m_FixedPointSet->GetPoints()->Begin(); typename FixedPointSetType::PointDataContainerIterator ItD = this->m_FixedPointSet->GetPointData()->Begin(); while (It != this->m_FixedPointSet->GetPoints()->End()) { if (label == ItD.Value()) { fixedPointSet->SetPoint(count++, It.Value()); } ++It; ++ItD; } return fixedPointSet; } template typename LabeledPointSetToPointSetMetricv4:: MovingPointSetPointer LabeledPointSetToPointSetMetricv4:: GetLabeledMovingPointSet(const LabelType label) const { auto movingPointSet = MovingPointSetType::New(); movingPointSet->Initialize(); typename MovingPointSetType::PointIdentifier count{}; typename MovingPointSetType::PointsContainerConstIterator It = this->m_MovingPointSet->GetPoints()->Begin(); typename MovingPointSetType::PointDataContainerIterator ItD = this->m_MovingPointSet->GetPointData()->Begin(); while (It != this->m_MovingPointSet->GetPoints()->End()) { if (label == ItD.Value()) { movingPointSet->SetPoint(count++, It.Value()); } ++It; ++ItD; } return movingPointSet; } template void LabeledPointSetToPointSetMetricv4:: DetermineCommonPointSetLabels() { this->m_FixedPointSetLabels.clear(); this->m_MovingPointSetLabels.clear(); this->m_CommonPointSetLabels.clear(); if (this->m_FixedPointSet->GetNumberOfPoints() > 0) { typename FixedPointSetType::PointDataContainerIterator It = this->m_FixedPointSet->GetPointData()->Begin(); while (It != this->m_FixedPointSet->GetPointData()->End()) { if (std::find(this->m_FixedPointSetLabels.begin(), this->m_FixedPointSetLabels.end(), It.Value()) == this->m_FixedPointSetLabels.end()) { this->m_FixedPointSetLabels.push_back(It.Value()); } ++It; } } std::sort(this->m_FixedPointSetLabels.begin(), this->m_FixedPointSetLabels.end()); if (this->m_MovingPointSet->GetNumberOfPoints() > 0) { typename MovingPointSetType::PointDataContainerIterator It = this->m_MovingPointSet->GetPointData()->Begin(); while (It != this->m_MovingPointSet->GetPointData()->End()) { if (std::find(this->m_MovingPointSetLabels.begin(), this->m_MovingPointSetLabels.end(), It.Value()) == this->m_MovingPointSetLabels.end()) { this->m_MovingPointSetLabels.push_back(It.Value()); } ++It; } } std::sort(this->m_MovingPointSetLabels.begin(), this->m_MovingPointSetLabels.end()); LabelSetType uncommonLabelSet; typename LabelSetType::const_iterator itF; for (itF = this->m_FixedPointSetLabels.begin(); itF != this->m_FixedPointSetLabels.end(); ++itF) { if (std::find(this->m_MovingPointSetLabels.begin(), this->m_MovingPointSetLabels.end(), *itF) != this->m_MovingPointSetLabels.end()) { this->m_CommonPointSetLabels.push_back(*itF); } else { uncommonLabelSet.push_back(*itF); } } if (!uncommonLabelSet.empty()) { itkWarningMacro("The label sets are not bijective."); } } template void LabeledPointSetToPointSetMetricv4::PrintSelf( std::ostream & os, Indent indent) const { using namespace print_helper; Superclass::PrintSelf(os, indent); itkPrintSelfObjectMacro(PointSetMetric); os << indent << "PointSetMetricClones: " << m_PointSetMetricClones << std::endl; os << indent << "FixedPointSetLabels: " << m_FixedPointSetLabels << std::endl; os << indent << "MovingPointSetLabels: " << m_MovingPointSetLabels << std::endl; os << indent << "CommonPointSetLabels: " << m_CommonPointSetLabels << std::endl; } } // end namespace itk #endif