/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#include "itkLBFGS2Optimizerv4.h"

#include "itkMacro.h"
#include "itkMath.h"

#include "lbfgs.h"

#include <cstddef>
#include <cstring>
#include <limits>
#include <memory>

namespace itk
{

/** Holds the liblbfgs parameter struct so that lbfgs.h is not exposed by the public header. */
class LBFGS2Optimizerv4::PrivateImplementationHolder
{
public:
  lbfgs_parameter_t m_Parameters;
};

LBFGS2Optimizerv4::LBFGS2Optimizerv4()
  : m_Pimpl(new PrivateImplementationHolder)
{
  // Initialize to the liblbfgs default parameters
  lbfgs_parameter_init(&m_Pimpl->m_Parameters);

  // 100 is the sentinel that GetStopConditionDescription() reports as
  // "Optimization not started".
  m_StatusCode = 100;
}

LBFGS2Optimizerv4::~LBFGS2Optimizerv4() = default;

/** Print the superclass state followed by every liblbfgs parameter. */
void
LBFGS2Optimizerv4::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  const lbfgs_parameter_t & params = m_Pimpl->m_Parameters;
  os << indent << "m: " << params.m << std::endl;
  os << indent << "epsilon: " << params.epsilon << std::endl;
  os << indent << "past: " << params.past << std::endl;
  os << indent << "delta: " << params.delta << std::endl;
  os << indent << "max_iterations: " << params.max_iterations << std::endl;
  os << indent << "linesearch: " << params.linesearch << std::endl;
  os << indent << "max_linesearch: " << params.max_linesearch << std::endl;
  os << indent << "min_step: " << params.min_step << std::endl;
  os << indent << "max_step: " << params.max_step << std::endl;
  os << indent << "ftol: " << params.ftol << std::endl;
  os << indent << "wolfe: " << params.wolfe << std::endl;
  os << indent << "gtol: " << params.gtol << std::endl;
  os << indent << "xtol: " << params.xtol << std::endl;
  os << indent << "orthantwise_c: " << params.orthantwise_c << std::endl;
  os << indent << "orthantwise_start: " << params.orthantwise_start << std::endl;
  os << indent << "orthantwise_end: " << params.orthantwise_end << std::endl;
}

/**
 * Register callbacks and run the lbfgs routine on the metric's parameters.
 *
 * Steps:
 *  1. Copy the metric's current parameters into an lbfgs-allocated buffer.
 *  2. Run lbfgs() with EvaluateCostCallback/UpdateProgressCallback bound to
 *     this optimizer.
 *  3. Store the lbfgs return code in m_StatusCode.
 *  4. Write the final parameters back to the metric.
 *
 * \exception ExceptionObject thrown in any of these cases:
 *  - the metric has local support;
 *  - there are no parameters;
 *  - the parameter count does not fit in an int;
 *  - the lbfgs buffer cannot be allocated.
 */
void
LBFGS2Optimizerv4::StartOptimization(bool doOnlyInitialization)
{
  // Check if everything is setup correctly
  Superclass::StartOptimization(doOnlyInitialization);
  // NOTE(review): doOnlyInitialization is only forwarded to the superclass;
  // lbfgs() below runs regardless of its value -- confirm this is intended.

  if (this->GetMetric()->HasLocalSupport())
  {
    itkExceptionMacro("The assigned transform has local-support. This is not supported for this optimizer. See the "
                      "optimizer documentation.");
  }

  this->InvokeEvent(StartEvent());

  // Copy parameters
  const ParametersType & parameters = this->m_Metric->GetParameters();
  const auto             numberOfParameters = parameters.GetSize();
  if (numberOfParameters == 0)
  {
    itkExceptionMacro(<< "Optimizer parameters are not initialized.");
  }
  // lbfgs() and lbfgs_malloc() take the problem size as an int; refuse sizes
  // that would be truncated instead of silently narrowing them.
  if (numberOfParameters > static_cast<std::size_t>(std::numeric_limits<int>::max()))
  {
    itkExceptionMacro(<< "Too many optimizer parameters (" << numberOfParameters << ") for lbfgs.");
  }
  const int         N = static_cast<int>(numberOfParameters);
  const std::size_t numberOfBytes = sizeof(LBFGS2Optimizerv4::PrecisionType) * numberOfParameters;

  // TODO: only needed if SSE is enabled
  // The buffer is owned by a unique_ptr so that lbfgs_free() runs on every
  // exit path. That includes an exception thrown by the metric from inside
  // the lbfgs callbacks.
  auto freeBuffer = [](LBFGS2Optimizerv4::PrecisionType * buffer) { lbfgs_free(buffer); };
  std::unique_ptr<LBFGS2Optimizerv4::PrecisionType, decltype(freeBuffer)> x(lbfgs_malloc(N), freeBuffer);
  if (x == nullptr)
  {
    itkExceptionMacro(<< "lbfgs_malloc failed to allocate " << N << " parameters.");
  }
  std::memcpy(x.get(), parameters.data_block(), numberOfBytes);

  // Run lbfgs
  m_StatusCode = lbfgs(N,
                       x.get(),
                       &this->m_CurrentMetricValue,
                       LBFGS2Optimizerv4::EvaluateCostCallback,
                       LBFGS2Optimizerv4::UpdateProgressCallback,
                       this,
                       &m_Pimpl->m_Parameters);

  // Match the behavior of other optimizer setting the current
  // iteration to the max when iteration limit is reached
  if (m_StatusCode == LBFGSERR_MAXIMUMITERATION)
  {
    ++this->m_CurrentIteration;
  }

  // Copy results out before the lbfgs buffer is released
  ParametersType optimizedParameters(numberOfParameters);
  std::memcpy(optimizedParameters.data_block(), x.get(), numberOfBytes);
  x.reset();

  this->m_Metric->SetParameters(optimizedParameters);
}

// LBFGS method callbacks
LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::EvaluateCostCallback(void * instance, const LBFGS2Optimizerv4::PrecisionType * x, LBFGS2Optimizerv4::PrecisionType * g, const int n, const LBFGS2Optimizerv4::PrecisionType step) { auto * optimizer = static_cast(instance); return optimizer->EvaluateCost(x, g, n, step); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::EvaluateCost(const LBFGS2Optimizerv4::PrecisionType * x, LBFGS2Optimizerv4::PrecisionType * g, const int n, const LBFGS2Optimizerv4::PrecisionType) { ParametersType xItk(n); std::memcpy(xItk.data_block(), x, n * sizeof(LBFGS2Optimizerv4::PrecisionType)); DerivativeType gItk(n); gItk.SetData(g, n, false); MeasureType value; this->m_Metric->SetParameters(xItk); this->m_Metric->GetValueAndDerivative(value, gItk); gItk *= -1; return value; } int LBFGS2Optimizerv4::UpdateProgressCallback(void * instance, const LBFGS2Optimizerv4::PrecisionType * x, const LBFGS2Optimizerv4::PrecisionType * g, const LBFGS2Optimizerv4::PrecisionType fx, const LBFGS2Optimizerv4::PrecisionType xnorm, const LBFGS2Optimizerv4::PrecisionType gnorm, const LBFGS2Optimizerv4::PrecisionType step, int n, int k, int ls) { auto * optimizer = static_cast(instance); return optimizer->UpdateProgress(x, g, fx, xnorm, gnorm, step, n, k, ls); } int LBFGS2Optimizerv4::UpdateProgress(const LBFGS2Optimizerv4::PrecisionType * x, const LBFGS2Optimizerv4::PrecisionType * g, const LBFGS2Optimizerv4::PrecisionType fx, const LBFGS2Optimizerv4::PrecisionType xnorm, const LBFGS2Optimizerv4::PrecisionType gnorm, const LBFGS2Optimizerv4::PrecisionType step, int, int k, int ls) { // Convert to 0-based ITK iteration counting this->m_CurrentIteration = k - 1; this->m_CurrentMetricValue = fx; m_CurrentGradient = g; m_CurrentParameter = x; m_CurrentParameterNorm = xnorm; m_CurrentGradientNorm = gnorm; m_CurrentStepSize = step; m_CurrentNumberOfEvaluations = ls; this->InvokeEvent(IterationEvent()); return 0; } const 
LBFGS2Optimizerv4::StopConditionReturnStringType LBFGS2Optimizerv4::GetStopConditionDescription() const { switch (m_StatusCode) { case 100: return "Optimization not started"; case LBFGS_SUCCESS: return "Converged"; case LBFGS_ALREADY_MINIMIZED: return "Already minimized"; case LBFGSERR_UNKNOWNERROR: return "Unknown error"; case LBFGSERR_LOGICERROR: return "Logic error"; case LBFGSERR_OUTOFMEMORY: return "Out of memory"; case LBFGSERR_CANCELED: return "Optimization canceled"; case LBFGSERR_INVALID_N: return "Invalid number of variables"; case LBFGSERR_INVALID_N_SSE: return "Invalid number of variables for SSE"; case LBFGSERR_INVALID_X_SSE: return "Invalid alignment of variables for SSE"; case LBFGSERR_INVALID_EPSILON: return "Invalid solution accuracy parameter"; case LBFGSERR_INVALID_TESTPERIOD: return "Invalid delta convergence distance"; case LBFGSERR_INVALID_DELTA: return "Invalid delta convergence tolerance"; case LBFGSERR_INVALID_LINESEARCH: return "Invalid linesearch specified"; case LBFGSERR_INVALID_MINSTEP: return "Invalid minimum linesearch step"; case LBFGSERR_INVALID_MAXSTEP: return "Invalid maximum linesearch step"; case LBFGSERR_INVALID_FTOL: return "Invalid linesearch accuracy"; case LBFGSERR_INVALID_WOLFE: return "Invalid wolfe coefficient"; case LBFGSERR_INVALID_GTOL: return "Invalid lnesearch gradient accuracy"; case LBFGSERR_INVALID_XTOL: return "Invalid machine precision tolerance"; case LBFGSERR_INVALID_MAXLINESEARCH: return "Invalid maximum linesearch iterations"; case LBFGSERR_INVALID_ORTHANTWISE: return "Invalid orthantwise coefficient"; case LBFGSERR_INVALID_ORTHANTWISE_START: return "Invalid orthantwise start"; case LBFGSERR_INVALID_ORTHANTWISE_END: return "Invalid orthantwise end"; case LBFGSERR_OUTOFINTERVAL: return "LineSearch went outside of the uncertainty interval"; case LBFGSERR_INCORRECT_TMINMAX: return "Interval of uncertainty became too small"; case LBFGSERR_ROUNDING_ERROR: return "A rounding error occurred or line-search steps 
have an insufficient reduction"; case LBFGSERR_MINIMUMSTEP: return "LineSearch became smaller than the minimum linesearch step size"; case LBFGSERR_MAXIMUMSTEP: return "LineSearch became bigger than the maximum linesearch step size"; case LBFGSERR_MAXIMUMLINESEARCH: return "LineSearch reached maximum umber of iterations"; case LBFGSERR_MAXIMUMITERATION: return "Reached maximum number of iterations"; case LBFGSERR_WIDTHTOOSMALL: return "Relative width of the interval became too small (less than machine tolerance)"; case LBFGSERR_INVALIDPARAMETERS: return "Negative line search step occurred"; case LBFGSERR_INCREASEGRADIENT: return "Current search direction increases objective function"; } return "Unknown status"; } // Disallow setting scale and weight void LBFGS2Optimizerv4 ::SetScales(const ScalesType &) { itkWarningMacro(<< "LBFGS optimizer does not support scaling. All scales are set to one."); m_Scales.SetSize(this->m_Metric->GetNumberOfLocalParameters()); m_Scales.Fill(NumericTraits::OneValue()); this->m_ScalesAreIdentity = true; } void LBFGS2Optimizerv4 ::SetWeights(const ScalesType) { itkWarningMacro(<< "LBFGS optimizer does not support weights. 
All weights are set to one."); m_Weights.SetSize(this->m_Metric->GetNumberOfLocalParameters()); m_Weights.Fill(NumericTraits::OneValue()); this->m_WeightsAreIdentity = true; } // A bunch of Set/Get methods for setting lbfgs parameters void LBFGS2Optimizerv4::SetHessianApproximationAccuracy(int m) { m_Pimpl->m_Parameters.m = m; this->Modified(); } int LBFGS2Optimizerv4::GetHessianApproximationAccuracy() const { return m_Pimpl->m_Parameters.m; } void LBFGS2Optimizerv4::SetSolutionAccuracy(LBFGS2Optimizerv4::PrecisionType epsilon) { m_Pimpl->m_Parameters.epsilon = epsilon; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetSolutionAccuracy() const { return m_Pimpl->m_Parameters.epsilon; } void LBFGS2Optimizerv4::SetDeltaConvergenceDistance(int nPast) { m_Pimpl->m_Parameters.past = nPast; this->Modified(); } int LBFGS2Optimizerv4::GetDeltaConvergenceDistance() const { return m_Pimpl->m_Parameters.past; } void LBFGS2Optimizerv4::SetDeltaConvergenceTolerance(LBFGS2Optimizerv4::PrecisionType tol) { m_Pimpl->m_Parameters.delta = tol; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetDeltaConvergenceTolerance() const { return m_Pimpl->m_Parameters.delta; } void LBFGS2Optimizerv4::SetMaximumIterations(int maxIterations) { m_Pimpl->m_Parameters.max_iterations = maxIterations; this->Modified(); } int LBFGS2Optimizerv4::GetMaximumIterations() const { return m_Pimpl->m_Parameters.max_iterations; } // translate to lbfgs.h enum // this translation is kind of annoying and error prone // but avoids exposing lbfgs.h void LBFGS2Optimizerv4::SetLineSearch(const LBFGS2Optimizerv4::LineSearchMethodEnum & linesearch) { // int lbfgsLineSearch = LBFGS_LINESEARCH_DEFAULT; if (linesearch == LineSearchMethodEnum::LINESEARCH_BACKTRACKING) { lbfgsLineSearch = LBFGS_LINESEARCH_BACKTRACKING; } else if (linesearch == LineSearchMethodEnum::LINESEARCH_BACKTRACKING_ARMIJO) { lbfgsLineSearch = LBFGS_LINESEARCH_BACKTRACKING_ARMIJO; } else if (linesearch 
== LineSearchMethodEnum::LINESEARCH_BACKTRACKING_WOLFE) { lbfgsLineSearch = LBFGS_LINESEARCH_BACKTRACKING_WOLFE; } else if (linesearch == LineSearchMethodEnum::LINESEARCH_BACKTRACKING_STRONG_WOLFE) { lbfgsLineSearch = LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE; } else if (linesearch == LineSearchMethodEnum::LINESEARCH_MORETHUENTE) { lbfgsLineSearch = LBFGS_LINESEARCH_MORETHUENTE; } m_Pimpl->m_Parameters.linesearch = lbfgsLineSearch; this->Modified(); } LBFGS2Optimizerv4::LineSearchMethodEnum LBFGS2Optimizerv4::GetLineSearch() const { LineSearchMethodEnum linesearch = LineSearchMethodEnum::LINESEARCH_DEFAULT; int lbfgsLineSearch = m_Pimpl->m_Parameters.linesearch; if (lbfgsLineSearch == LBFGS_LINESEARCH_BACKTRACKING) { linesearch = LineSearchMethodEnum::LINESEARCH_BACKTRACKING; } else if (lbfgsLineSearch == LBFGS_LINESEARCH_BACKTRACKING_ARMIJO) { linesearch = LineSearchMethodEnum::LINESEARCH_BACKTRACKING_ARMIJO; } else if (lbfgsLineSearch == LBFGS_LINESEARCH_BACKTRACKING_WOLFE) { linesearch = LineSearchMethodEnum::LINESEARCH_BACKTRACKING_WOLFE; } else if (lbfgsLineSearch == LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { linesearch = LineSearchMethodEnum::LINESEARCH_BACKTRACKING_STRONG_WOLFE; } else if (lbfgsLineSearch == LBFGS_LINESEARCH_MORETHUENTE) { linesearch = LineSearchMethodEnum::LINESEARCH_MORETHUENTE; } return linesearch; } void LBFGS2Optimizerv4::SetMaximumLineSearchEvaluations(int n) { m_Pimpl->m_Parameters.max_linesearch = n; this->Modified(); } int LBFGS2Optimizerv4::GetMaximumLineSearchEvaluations() const { return m_Pimpl->m_Parameters.max_linesearch; } void LBFGS2Optimizerv4::SetMinimumLineSearchStep(LBFGS2Optimizerv4::PrecisionType step) { m_Pimpl->m_Parameters.min_step = step; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetMinimumLineSearchStep() const { return m_Pimpl->m_Parameters.min_step; } void LBFGS2Optimizerv4::SetMaximumLineSearchStep(LBFGS2Optimizerv4::PrecisionType step) { m_Pimpl->m_Parameters.max_step = step; 
this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetMaximumLineSearchStep() const { return m_Pimpl->m_Parameters.max_step; } void LBFGS2Optimizerv4::SetLineSearchAccuracy(LBFGS2Optimizerv4::PrecisionType ftol) { m_Pimpl->m_Parameters.ftol = ftol; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetLineSearchAccuracy() const { return m_Pimpl->m_Parameters.ftol; } void LBFGS2Optimizerv4::SetWolfeCoefficient(LBFGS2Optimizerv4::PrecisionType wc) { m_Pimpl->m_Parameters.wolfe = wc; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetWolfeCoefficient() const { return m_Pimpl->m_Parameters.wolfe; } void LBFGS2Optimizerv4::SetLineSearchGradientAccuracy(LBFGS2Optimizerv4::PrecisionType gtol) { m_Pimpl->m_Parameters.gtol = gtol; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetLineSearchGradientAccuracy() const { return m_Pimpl->m_Parameters.gtol; } void LBFGS2Optimizerv4::SetMachinePrecisionTolerance(LBFGS2Optimizerv4::PrecisionType xtol) { m_Pimpl->m_Parameters.xtol = xtol; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetMachinePrecisionTolerance() const { return m_Pimpl->m_Parameters.xtol; } void LBFGS2Optimizerv4::SetOrthantwiseCoefficient(LBFGS2Optimizerv4::PrecisionType orthant_c) { m_Pimpl->m_Parameters.orthantwise_c = orthant_c; this->Modified(); } LBFGS2Optimizerv4::PrecisionType LBFGS2Optimizerv4::GetOrthantwiseCoefficient() const { return m_Pimpl->m_Parameters.orthantwise_c; } void LBFGS2Optimizerv4::SetOrthantwiseStart(int start) { m_Pimpl->m_Parameters.orthantwise_start = start; this->Modified(); } int LBFGS2Optimizerv4::GetOrthantwiseStart() const { return m_Pimpl->m_Parameters.orthantwise_start; } void LBFGS2Optimizerv4::SetOrthantwiseEnd(int end) { m_Pimpl->m_Parameters.orthantwise_end = end; this->Modified(); } int LBFGS2Optimizerv4::GetOrthantwiseEnd() const { return m_Pimpl->m_Parameters.orthantwise_end; } /** Print enum values */ 
std::ostream & operator<<(std::ostream & out, const LBFGS2Optimizerv4Enums::LineSearchMethod value) { return out << [value] { switch (value) { case LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_DEFAULT: return "itk::LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_DEFAULT"; // case LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_MORETHUENTE: // return "itk::LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_MORETHUENTE"; case LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING_ARMIJO: return "itk::LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING_ARMIJO"; case LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING: return "itk::LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING"; // case LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING_WOLFE: // return "itk::LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING_WOLFE"; case LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING_STRONG_WOLFE: return "itk::LBFGS2Optimizerv4Enums::LineSearchMethod::LINESEARCH_BACKTRACKING_STRONG_WOLFE"; default: return "INVALID VALUE FOR itk::LBFGS2Optimizerv4Enums::LineSearchMethod"; } }(); } } // end namespace itk