/*=========================================================================
*
* Copyright NumFOCUS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#ifndef itkExpectationMaximizationMixtureModelEstimator_h
#define itkExpectationMaximizationMixtureModelEstimator_h
#include "ITKStatisticsExport.h"
#include "itkMixtureModelComponentBase.h"
#include "itkGaussianMembershipFunction.h"
#include "itkSimpleDataObjectDecorator.h"
namespace itk
{
namespace Statistics
{
/** \class ExpectationMaximizationMixtureModelEstimatorEnums
* \brief Contains all enum classes used by ExpectationMaximizationMixtureModelEstimator class.
* \ingroup ITKStatistics
*/
class ExpectationMaximizationMixtureModelEstimatorEnums
{
public:
  /** \class TERMINATION_CODE
   * \ingroup ITKStatistics
   * Outcome reported once the EM iterations stop: whether the model
   * parameters converged or not. Stored in a single byte (uint8_t). */
  enum class TERMINATION_CODE : uint8_t
  {
    CONVERGED = 0,     ///< the model parameters converged
    NOT_CONVERGED = 1, ///< iteration stopped without reaching convergence
  };
};
// Stream-insertion operator so a TERMINATION_CODE value can be written
// directly to any std::ostream.
extern ITKStatistics_EXPORT std::ostream &
operator<<(std::ostream & os, const ExpectationMaximizationMixtureModelEstimatorEnums::TERMINATION_CODE code);
/**
* \class ExpectationMaximizationMixtureModelEstimator
* \brief This class generates the parameter estimates for a mixture
* model using expectation maximization strategy.
*
* The first template argument is the type of the target sample
* data. This estimator expects one or more mixture model component
* objects of the classes derived from the
* MixtureModelComponentBase. The actual component (or module)
* parameters are updated by each component. Users can think of this
* class as a strategy or an integration point for the EM
* procedure. The initial proportions (SetInitialProportions), the
* input sample (SetSample), the mixture model components
* (AddComponent), and the maximum number of iterations (SetMaximumIteration) are
* required. The EM procedure terminates when the current iteration
* reaches the maximum number of iterations or the model parameters converge.
*
* Recent API changes:
* The static const macro to get the length of a measurement vector,
* \c MeasurementVectorSize has been removed to allow the length of a measurement
* vector to be specified at run time. It is now obtained at run time from the
* sample set as input. Please use the function
* GetMeasurementVectorSize() to get the length.
*
* \sa MixtureModelComponentBase, GaussianMixtureModelComponent
* \ingroup ITKStatistics
*
* \sphinx
* \sphinxexample{Numerics/Statistics/2DGaussianMixtureModelExpectMax,2D Gaussian Mixture Model Expectation Maximum}
* \sphinxexample{Numerics/Statistics/DistributionOfPixelsUsingGMM,Distribution Of Pixels Using GMM EM}
* \sphinxexample{Numerics/Statistics/DistributeSamplingUsingGMM,Distribute Sampling Using GMM EM}
* \endsphinx
*/
template
class ITK_TEMPLATE_EXPORT ExpectationMaximizationMixtureModelEstimator : public Object
{
public:
/** Standard class type alias */
using Self = ExpectationMaximizationMixtureModelEstimator;
using Superclass = Object;
using Pointer = SmartPointer;
using ConstPointer = SmartPointer;
/** \see LightObject::GetNameOfClass() */
itkOverrideGetNameOfClassMacro(ExpectationMaximizationMixtureModelEstimator);
itkNewMacro(Self);
/** TSample template argument related type alias */
using SampleType = TSample;
using MeasurementType = typename TSample::MeasurementType;
using MeasurementVectorType = typename TSample::MeasurementVectorType;
/** Typedef required to generate dataobject decorated output that can
* be plugged into SampleClassifierFilter */
using GaussianMembershipFunctionType = GaussianMembershipFunction;
using GaussianMembershipFunctionPointer = typename GaussianMembershipFunctionType::Pointer;
using MembershipFunctionType = MembershipFunctionBase;
using MembershipFunctionPointer = typename MembershipFunctionType::ConstPointer;
using MembershipFunctionVectorType = std::vector;
using MembershipFunctionVectorObjectType = SimpleDataObjectDecorator;
using MembershipFunctionVectorObjectPointer = typename MembershipFunctionVectorObjectType::Pointer;
/** Type of the mixture model component base class */
using ComponentType = MixtureModelComponentBase;
/** Type of the component pointer storage */
using ComponentVectorType = std::vector;
/** Type of the membership function base class */
using ComponentMembershipFunctionType = MembershipFunctionBase;
/** Type of the array of the proportion values */
using ProportionVectorType = Array;
/** Sets the target data that will be classified by this */
void
SetSample(const TSample * sample);
/** Returns the target data */
const TSample *
GetSample() const;
/** Set/Gets the initial proportion values. The size of proportion
* vector should be same as the number of component (or classes) */
void
SetInitialProportions(ProportionVectorType & proportions);
const ProportionVectorType &
GetInitialProportions() const;
/** Gets the result proportion values */
const ProportionVectorType &
GetProportions() const;
/** type alias for decorated array of proportion */
using MembershipFunctionsWeightsArrayObjectType = SimpleDataObjectDecorator;
using MembershipFunctionsWeightsArrayPointer = typename MembershipFunctionsWeightsArrayObjectType::Pointer;
/** Get method for data decorated Membership functions weights array */
const MembershipFunctionsWeightsArrayObjectType *
GetMembershipFunctionsWeightsArray() const;
/** Set/Gets the maximum number of iterations. When the optimization
* process reaches the maximum number of iterations, even if the
* class parameters aren't converged, the optimization process
* stops. */
void
SetMaximumIteration(int numberOfIterations);
int
GetMaximumIteration() const;
/** Gets the current iteration. */
int
GetCurrentIteration()
{
return m_CurrentIteration;
}
/** Adds a new component (or class). */
int
AddComponent(ComponentType * component);
/** Gets the total number of classes currently plugged in. */
unsigned int
GetNumberOfComponents() const;
/** Runs the optimization process. */
void
Update();
using TERMINATION_CODE_ENUM = ExpectationMaximizationMixtureModelEstimatorEnums::TERMINATION_CODE;
#if !defined(ITK_LEGACY_REMOVE)
/**Exposes enums values for backwards compatibility*/
static constexpr TERMINATION_CODE_ENUM CONVERGED = TERMINATION_CODE_ENUM::CONVERGED;
static constexpr TERMINATION_CODE_ENUM NOT_CONVERGED = TERMINATION_CODE_ENUM::NOT_CONVERGED;
#endif
/** Gets the termination status */
TERMINATION_CODE_ENUM
GetTerminationCode() const;
/** Gets the membership function specified by componentIndex
argument. */
ComponentMembershipFunctionType *
GetComponentMembershipFunction(int componentIndex) const;
/** Output Membership function vector containing the membership functions with
* the final optimized parameters */
const MembershipFunctionVectorObjectType *
GetOutput() const;
protected:
ExpectationMaximizationMixtureModelEstimator();
~ExpectationMaximizationMixtureModelEstimator() override = default;
void
PrintSelf(std::ostream & os, Indent indent) const override;
bool
CalculateDensities();
double
CalculateExpectation() const;
bool
UpdateComponentParameters();
bool
UpdateProportions();
/** Starts the estimation process */
void
GenerateData();
private:
/** Target data sample pointer*/
const TSample * m_Sample{};
int m_MaxIteration{ 100 };
int m_CurrentIteration{ 0 };
TERMINATION_CODE_ENUM m_TerminationCode{ TERMINATION_CODE_ENUM::NOT_CONVERGED };
ComponentVectorType m_ComponentVector{};
ProportionVectorType m_InitialProportions{};
ProportionVectorType m_Proportions{};
MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject{};
MembershipFunctionsWeightsArrayPointer m_MembershipFunctionsWeightArrayObject{};
}; // end of class
} // end of namespace Statistics
} // end of namespace itk
#ifndef ITK_MANUAL_INSTANTIATION
# include "itkExpectationMaximizationMixtureModelEstimator.hxx"
#endif
#endif