/*========================================================================= * * Copyright NumFOCUS * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkImageClassifierBase_hxx #define itkImageClassifierBase_hxx namespace itk { template void ImageClassifierBase::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); itkPrintSelfObjectMacro(InputImage); itkPrintSelfObjectMacro(ClassifiedImage); } template void ImageClassifierBase::GenerateData() { this->Classify(); } template void ImageClassifierBase::Classify() { ClassifiedImagePointer classifiedImage = this->GetClassifiedImage(); // Check if the an output buffer has been allocated if (!classifiedImage) { this->Allocate(); // To trigger the pipeline process this->Modified(); } // Set the iterators and the pixel type definition for the input image InputImageConstPointer inputImage = this->GetInputImage(); InputImageConstIterator inIt(inputImage, inputImage->GetBufferedRegion()); // Set the iterators and the pixel type definition for the classified image classifiedImage = this->GetClassifiedImage(); ClassifiedImageIterator classifiedIt(classifiedImage, classifiedImage->GetBufferedRegion()); // Set up the vector to store the image data InputImagePixelType inputImagePixel; ClassifiedImagePixelType outputClassifiedLabel; // Set up the storage containers to record the probability // measures for each class. unsigned int numberOfClasses = this->GetNumberOfMembershipFunctions(); std::vector discriminantScores; discriminantScores.resize(numberOfClasses); unsigned int classLabel; unsigned int classIndex; // support progress methods/callbacks SizeValueType totalPixels = inputImage->GetBufferedRegion().GetNumberOfPixels(); SizeValueType updateVisits = totalPixels / 10; if (updateVisits < 1) { updateVisits = 1; } int k = 0; for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++classifiedIt, ++k) { if (!(k % updateVisits)) { this->UpdateProgress(static_cast(k) / static_cast(totalPixels)); } // Read the input vector inputImagePixel = inIt.Get(); for (classIndex = 0; classIndex < numberOfClasses; ++classIndex) { discriminantScores[classIndex] = (this->GetMembershipFunction(classIndex))->Evaluate(inputImagePixel); } classLabel = static_cast(this->GetDecisionRule()->Evaluate(discriminantScores)); outputClassifiedLabel = ClassifiedImagePixelType(classLabel); classifiedIt.Set(outputClassifiedLabel); } } template void ImageClassifierBase::Allocate() { InputImageConstPointer inputImage = this->GetInputImage(); InputImageSizeType inputImageSize = inputImage->GetBufferedRegion().GetSize(); ClassifiedImagePointer classifiedImage = TClassifiedImage::New(); this->SetClassifiedImage(classifiedImage); const typename TClassifiedImage::RegionType classifiedImageRegion(inputImageSize); classifiedImage->SetLargestPossibleRegion(classifiedImageRegion); classifiedImage->SetBufferedRegion(classifiedImageRegion); classifiedImage->Allocate(); } template std::vector ImageClassifierBase::GetPixelMembershipValue(const InputImagePixelType inputImagePixel) { unsigned int numberOfClasses = this->GetNumberOfClasses(); std::vector pixelMembershipValue(numberOfClasses); for (unsigned int classIndex = 0; classIndex < numberOfClasses; ++classIndex) { pixelMembershipValue[classIndex] = (this->GetMembershipFunction(classIndex))->Evaluate(inputImagePixel); } return pixelMembershipValue; } } // namespace itk #endif