/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkImageClassifierBase_hxx #define itkImageClassifierBase_hxx #include "itkImageClassifierBase.h" namespace itk { /** * PrintSelf */ template void ImageClassifierBase::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); os << indent << "General Image Classifier / Clusterer" << std::endl; os << indent << "ClassifiedImage: "; os << m_ClassifiedImage.GetPointer() << std::endl; os << indent << "InputImage: "; os << m_InputImage.GetPointer() << std::endl; } // end PrintSelf /** * Generate data (start the classification process) */ template void ImageClassifierBase::GenerateData() { this->Classify(); } // end Generate data //------------------------------------------------------------------ // The core function where classification is carried out //------------------------------------------------------------------ template void ImageClassifierBase::Classify() { ClassifiedImagePointer classifiedImage = this->GetClassifiedImage(); // Check if the an output buffer has been allocated if (!classifiedImage) { this->Allocate(); // To trigger the pipeline process this->Modified(); } //-------------------------------------------------------------------- // Set the iterators and the pixel type definition for the input image //------------------------------------------------------------------- InputImageConstPointer inputImage = this->GetInputImage(); InputImageConstIterator inIt(inputImage, inputImage->GetBufferedRegion()); //-------------------------------------------------------------------- // Set the iterators and the pixel type definition for the classified image //-------------------------------------------------------------------- classifiedImage = this->GetClassifiedImage(); ClassifiedImageIterator classifiedIt(classifiedImage, classifiedImage->GetBufferedRegion()); //-------------------------------------------------------------------- // Set up the vector to store the image data InputImagePixelType inputImagePixel; ClassifiedImagePixelType outputClassifiedLabel; // Set up the storage containers to record the probability // measures for each class. unsigned int numberOfClasses = this->GetNumberOfMembershipFunctions(); std::vector discriminantScores; discriminantScores.resize(numberOfClasses); unsigned int classLabel; unsigned int classIndex; // support progress methods/callbacks SizeValueType totalPixels = inputImage->GetBufferedRegion().GetNumberOfPixels(); SizeValueType updateVisits = totalPixels / 10; if (updateVisits < 1) { updateVisits = 1; } int k = 0; for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++classifiedIt, ++k) { if (!(k % updateVisits)) { this->UpdateProgress((float)k / (float)totalPixels); } // Read the input vector inputImagePixel = inIt.Get(); for (classIndex = 0; classIndex < numberOfClasses; classIndex++) { discriminantScores[classIndex] = (this->GetMembershipFunction(classIndex))->Evaluate(inputImagePixel); } classLabel = static_cast(this->GetDecisionRule()->Evaluate(discriminantScores)); outputClassifiedLabel = ClassifiedImagePixelType(classLabel); classifiedIt.Set(outputClassifiedLabel); } // end for (looping through the dataset) } // end Classify /** * Allocate */ template void ImageClassifierBase::Allocate() { InputImageConstPointer inputImage = this->GetInputImage(); InputImageSizeType inputImageSize = inputImage->GetBufferedRegion().GetSize(); ClassifiedImagePointer classifiedImage = TClassifiedImage::New(); this->SetClassifiedImage(classifiedImage); typename TClassifiedImage::IndexType classifiedImageIndex; classifiedImageIndex.Fill(0); typename TClassifiedImage::RegionType classifiedImageRegion; classifiedImageRegion.SetSize(inputImageSize); classifiedImageRegion.SetIndex(classifiedImageIndex); classifiedImage->SetLargestPossibleRegion(classifiedImageRegion); classifiedImage->SetBufferedRegion(classifiedImageRegion); classifiedImage->Allocate(); } template std::vector ImageClassifierBase::GetPixelMembershipValue(const InputImagePixelType inputImagePixel) { unsigned int numberOfClasses = this->GetNumberOfClasses(); std::vector pixelMembershipValue(numberOfClasses); for (unsigned int classIndex = 0; classIndex < numberOfClasses; classIndex++) { pixelMembershipValue[classIndex] = (this->GetMembershipFunction(classIndex))->Evaluate(inputImagePixel); } // Return the membership value of the return pixelMembershipValue; } } // namespace itk #endif