diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 418c9052bb06e8cf4c2dd2d557488037c9ca5bff..ef933451346da7e8bfb4513b6b22b7d8ab9f1580 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -41,7 +41,7 @@ pylint: allow_failure: true script: - sudo apt update && sudo apt install pylint -y - - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 $OTBTF_SRC/python + - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/python codespell: stage: Static Analysis @@ -54,7 +54,7 @@ cppcheck: allow_failure: true script: - sudo apt update && sudo apt install cppcheck -y - - cppcheck --enable=all --error-exitcode=1 $OTBTF_SRC/ + - cd $OTBTF_SRC/ && cppcheck --enable=all --error-exitcode=1 -I include/ --suppress=missingInclude --suppress=unusedFunction . ctest: stage: Test diff --git a/app/otbDensePolygonClassStatistics.cxx b/app/otbDensePolygonClassStatistics.cxx index 62f9eea9e87e8b11490360a7d2eeb59ab4371c64..fa7c2701db9bbe959ec2787f8e04c1fd6915cd3e 100644 --- a/app/otbDensePolygonClassStatistics.cxx +++ b/app/otbDensePolygonClassStatistics.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -9,18 +9,24 @@ PURPOSE. See the above copyright notices for more information. =========================================================================*/ -#include "otbWrapperApplication.h" +#include "itkFixedArray.h" +#include "itkObjectFactory.h" #include "otbWrapperApplicationFactory.h" +// Application engine +#include "otbStandardFilterWatcher.h" +#include "itkFixedArray.h" + +// Filters #include "otbStatisticsXMLFileWriter.h" #include "otbWrapperElevationParametersHandler.h" - #include "otbVectorDataToLabelImageFilter.h" #include "otbImageToNoDataMaskFilter.h" #include "otbStreamingStatisticsMapFromLabelImageFilter.h" #include "otbVectorDataIntoImageProjectionFilter.h" #include "otbImageToVectorImageCastFilter.h" +// OGR #include "otbOGR.h" namespace otb @@ -37,15 +43,14 @@ class DensePolygonClassStatistics : public Application { public: /** Standard class typedefs. */ - typedef DensePolygonClassStatistics Self; + typedef DensePolygonClassStatistics Self; typedef Application Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); - - itkTypeMacro(DensePolygonClassStatistics, otb::Application); + itkTypeMacro(DensePolygonClassStatistics, Application); /** DataObjects typedef */ typedef UInt32ImageType LabelImageType; @@ -67,14 +72,7 @@ public: typedef otb::StatisticsXMLFileWriter<FloatVectorImageType::PixelType> StatWriterType; - -private: - DensePolygonClassStatistics() - { - - } - - void DoInit() override + void DoInit() { SetName("DensePolygonClassStatistics"); SetDescription("Computes statistics on a training polygon set."); @@ -88,7 +86,6 @@ private: " - number of samples per geometry\n"); SetDocLimitations("None"); SetDocAuthors("Remi Cresson"); - SetDocSeeAlso(" "); AddDocTag(Tags::Learning); @@ -115,67 +112,11 @@ private: SetDocExampleParameterValue("field", "label"); SetDocExampleParameterValue("out","polygonStat.xml"); - SetOfficialDocLink(); - } - - void DoUpdateParameters() override - { - if ( HasValue("vec") ) - { - std::string vectorFile = GetParameterString("vec"); - ogr::DataSource::Pointer ogrDS = - ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); - ogr::Layer layer = ogrDS->GetLayer(0); - ogr::Feature feature = layer.ogr().GetNextFeature(); - - ClearChoices("field"); - - for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++) - { - std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef(); - key = item; - std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum); - std::transform(key.begin(), end, key.begin(), tolower); - - OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType(); - - if(fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64) - { - std::string tmpKey="field."+key.substr(0, end - key.begin()); - AddChoice(tmpKey,item); - } - } - } - - // Check that the extension of the output parameter is XML (mandatory for - // StatisticsXMLFileWriter) - // Check it here to trigger the error before polygons analysis - - if ( HasValue("out") ) - { - // Store filename extension - // Check that the right extension is given : expected .xml - const std::string extension = itksys::SystemTools::GetFilenameLastExtension(this->GetParameterString("out")); - - if (itksys::SystemTools::LowerCase(extension) != ".xml") - { - otbAppLogFATAL( << extension << " is a wrong extension for parameter \"out\": Expected .xml" ); - } - } } - void DoExecute() override + void DoExecute() { - // Filters - VectorDataReprojFilterType::Pointer m_VectorDataReprojectionFilter; - RasterizeFilterType::Pointer m_RasterizeFIDFilter; - RasterizeFilterType::Pointer m_RasterizeClassFilter; - NoDataMaskFilterType::Pointer m_NoDataFilter; - CastFilterType::Pointer m_NoDataCastFilter; - StatsFilterType::Pointer m_FIDStatsFilter; - StatsFilterType::Pointer m_ClassStatsFilter; - // Retrieve the field name std::vector<int> selectedCFieldIdx = GetSelectedItems("field"); @@ -254,14 +195,72 @@ private: fidMap.erase(intNoData); classMap.erase(intNoData); - StatWriterType::Pointer statWriter = StatWriterType::New(); - statWriter->SetFileName(this->GetParameterString("out")); - statWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerClass",classMap); - statWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerVector",fidMap); - statWriter->Update(); + m_StatWriter = StatWriterType::New(); + m_StatWriter->SetFileName(this->GetParameterString("out")); + m_StatWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerClass", classMap); + m_StatWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerVector", fidMap); + m_StatWriter->Update(); } - + + void DoUpdateParameters() + { + if (HasValue("vec")) + { + std::string vectorFile = GetParameterString("vec"); + ogr::DataSource::Pointer ogrDS = + ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); + ogr::Layer layer = ogrDS->GetLayer(0); + ogr::Feature feature = layer.ogr().GetNextFeature(); + + ClearChoices("field"); + + for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++) + { + std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef(); + key = item; + std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum); + std::transform(key.begin(), end, key.begin(), tolower); + + OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType(); + + if(fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64) + { + std::string tmpKey="field."+key.substr(0, end - key.begin()); + AddChoice(tmpKey,item); + } + } + } + + // Check that the extension of the output parameter is XML (mandatory for + // StatisticsXMLFileWriter) + // Check it here to trigger the error before polygons analysis + + if (HasValue("out")) + { + // Store filename extension + // Check that the right extension is given : expected .xml + const std::string extension = itksys::SystemTools::GetFilenameLastExtension(this->GetParameterString("out")); + + if (itksys::SystemTools::LowerCase(extension) != ".xml") + { + otbAppLogFATAL( << extension << " is a wrong extension for parameter \"out\": Expected .xml" ); + } + } + } + + + +private: + // Filters + VectorDataReprojFilterType::Pointer m_VectorDataReprojectionFilter; + RasterizeFilterType::Pointer m_RasterizeFIDFilter; + RasterizeFilterType::Pointer m_RasterizeClassFilter; + NoDataMaskFilterType::Pointer m_NoDataFilter; + CastFilterType::Pointer m_NoDataCastFilter; + StatsFilterType::Pointer m_FIDStatsFilter; + StatsFilterType::Pointer m_ClassStatsFilter; + StatWriterType::Pointer m_StatWriter; }; diff --git a/app/otbImageClassifierFromDeepFeatures.cxx b/app/otbImageClassifierFromDeepFeatures.cxx index 98763559bfebbe710fd3e5a7a9e0b9214bc4f111..f3ffd2731307ecd0e0f7afdb5f6e78b9a4ac4163 100644 --- a/app/otbImageClassifierFromDeepFeatures.cxx +++ b/app/otbImageClassifierFromDeepFeatures.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -61,7 +61,6 @@ private: // Populate group ShareParameter(ss_key_group.str(), "tfmodel." + ss_key_group.str(), ss_desc_group.str()); - } @@ -107,10 +106,8 @@ private: ShareParameter("out" , "classif.out" , "Output image" , "Output image" ); ShareParameter("confmap" , "classif.confmap" , "Confidence map image", "Confidence map image"); ShareParameter("ram" , "classif.ram" , "Ram" , "Ram" ); - } - void DoUpdateParameters() { UpdateInternalParameters("classif"); @@ -122,12 +119,8 @@ private: GetInternalApplication("classif")->SetParameterInputImage("in", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); UpdateInternalParameters("classif"); ExecuteInternal("classif"); - } // DOExecute() - - void AfterExecuteAndWriteOutputs() - { - // Nothing to do } + }; } // namespace Wrapper } // namespace otb diff --git a/app/otbLabelImageSampleSelection.cxx b/app/otbLabelImageSampleSelection.cxx index 40591b9c42d648ec64ebfd90dfc7ccc5755e68ab..50396fa0ac4ded119bd31268327fea4032b15f2f 100644 --- a/app/otbLabelImageSampleSelection.cxx +++ b/app/otbLabelImageSampleSelection.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -35,7 +35,7 @@ class LabelImageSampleSelection : public Application { public: /** Standard class typedefs. */ - typedef LabelImageSampleSelection Self; + typedef LabelImageSampleSelection Self; typedef Application Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; @@ -385,4 +385,4 @@ private: } // end namespace wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::LabelImageSampleSelection ) +OTB_APPLICATION_EXPORT(otb::Wrapper::LabelImageSampleSelection) diff --git a/app/otbPatchesExtraction.cxx b/app/otbPatchesExtraction.cxx index 1191486620699d49569793cd50b4ef109128f7be..7b0ce4565e687da42bdf9adb71cadd883acd56cc 100644 --- a/app/otbPatchesExtraction.cxx +++ b/app/otbPatchesExtraction.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -136,10 +136,6 @@ public: } } - void DoUpdateParameters() - { - } - void DoInit() { @@ -237,6 +233,12 @@ public: } } + + + void DoUpdateParameters() + { + } + private: std::vector<SourceBundle> m_Bundles; @@ -245,4 +247,4 @@ private: } // end namespace wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::PatchesExtraction ) +OTB_APPLICATION_EXPORT(otb::Wrapper::PatchesExtraction) diff --git a/app/otbPatchesSelection.cxx b/app/otbPatchesSelection.cxx index eb6a31a89cc456dae932b74cc798462cdec19737..5cbdc82a3e0b2d665d4d332d6cddee979d22e3bc 100644 --- a/app/otbPatchesSelection.cxx +++ b/app/otbPatchesSelection.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -68,7 +68,7 @@ class PatchesSelection : public Application { public: /** Standard class typedefs. */ - typedef PatchesSelection Self; + typedef PatchesSelection Self; typedef Application Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; @@ -100,11 +100,6 @@ public: typedef itk::MaskImageFilter<UInt8ImageType, UInt8ImageType, UInt8ImageType> MaskImageFilterType; - void DoUpdateParameters() - { - } - - void DoInit() { @@ -167,22 +162,15 @@ public: { public: SampleBundle(){} - SampleBundle(unsigned int nClasses){ - dist = DistributionType(nClasses); - id = 0; + explicit SampleBundle(unsigned int nClasses): dist(DistributionType(nClasses)), id(0), black(true){ (void) point; - black = true; (void) index; } ~SampleBundle(){} - SampleBundle(const SampleBundle & other){ - dist = other.GetDistribution(); - id = other.GetSampleID(); - point = other.GetPosition(); - black = other.GetBlack(); - index = other.GetIndex(); - } + SampleBundle(const SampleBundle & other): dist(other.GetDistribution()), id(other.GetSampleID()), + point(other.GetPosition()), black(other.GetBlack()), index(other.GetIndex()) + {} DistributionType GetDistribution() const { @@ -539,7 +527,7 @@ public: PopulateVectorData(seed); } - void PopulateVectorData(std::vector<SampleBundle> & samples) + void PopulateVectorData(const std::vector<SampleBundle> & samples) { // Get data tree DataTreeType::Pointer treeTrain = m_OutVectorDataTrain->GetDataTree(); @@ -657,6 +645,11 @@ public: } + + void DoUpdateParameters() + { + } + private: RadiusType m_Radius; IsNoDataFilterType::Pointer m_NoDataFilter; diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx index f381ea6628809407ba91e05146b229238414e694..47a8c95730443c0c3345933d71f63a1bcbcb556d 100644 --- a/app/otbTensorflowModelServe.cxx +++ b/app/otbTensorflowModelServe.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -62,10 +62,6 @@ public: /** Typedefs for images */ typedef FloatVectorImageType::SizeType SizeType; - void DoUpdateParameters() - { - } - // // Store stuff related to one source // @@ -120,9 +116,12 @@ public: AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() ); AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); + SetDefaultParameterInt (ss_key_dims_x.str(), 1); AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); + SetDefaultParameterInt (ss_key_dims_y.str(), 1); AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str()); + MandatoryOff (ss_key_ph.str()); // Add a new bundle ProcessObjectsBundle bundle; @@ -184,7 +183,7 @@ public: SetDefaultParameterFloat ("output.spcscale", 1.0); SetParameterDescription ("output.spcscale", "The output image size/scale and spacing*scale where size and spacing corresponds to the first input"); AddParameter(ParameterType_StringList, "output.names", "Names of the output tensors"); - MandatoryOn ("output.names"); + MandatoryOff ("output.names"); // Output Field of Expression AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)"); @@ -329,6 +328,11 @@ public: SetParameterOutputImage("out", m_TFFilter->GetOutput()); } } + + + void DoUpdateParameters() + { + } private: diff --git a/app/otbTensorflowModelTrain.cxx b/app/otbTensorflowModelTrain.cxx index ffb88e4c2e32b6d885a0ed882cc0b50a8dfe4320..e79019980537f873efc85cc8dd22e4bd05fce82e 100644 --- a/app/otbTensorflowModelTrain.cxx +++ b/app/otbTensorflowModelTrain.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -362,7 +362,7 @@ public: // // Get user placeholders // - TrainModelFilterType::DictType GetUserPlaceholders(const std::string key) + TrainModelFilterType::DictType GetUserPlaceholders(const std::string & key) { TrainModelFilterType::DictType dict; TrainModelFilterType::StringList expressions = GetParameterStringList(key); diff --git a/app/otbTrainClassifierFromDeepFeatures.cxx b/app/otbTrainClassifierFromDeepFeatures.cxx index ada83c51fdb8674a39e8099120e5f796cd6f8481..39ac41898a450159a4c50133883006c4466bcb51 100644 --- a/app/otbTrainClassifierFromDeepFeatures.cxx +++ b/app/otbTrainClassifierFromDeepFeatures.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -121,11 +121,6 @@ private: GetInternalApplication("train")->AddImageToParameterInputImageList("io.il", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); UpdateInternalParameters("train"); ExecuteInternal("train"); - } // DOExecute() - - void AfterExecuteAndWriteOutputs() - { - // Nothing to do } }; diff --git a/include/otbTensorflowCommon.cxx b/include/otbTensorflowCommon.cxx index 300f4b5a2653ab5863d147e0efb97033457c5d51..662c9d3e979c5e67ccf9effc4564c9d9fd5c6d0e 100644 --- a/include/otbTensorflowCommon.cxx +++ b/include/otbTensorflowCommon.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowCommon.h b/include/otbTensorflowCommon.h index 512ff5d9b86d48c2db27a545ba716d41c6b88a5c..fbd7281035185c2acc3a56dac3850a23d76280df 100644 --- a/include/otbTensorflowCommon.h +++ b/include/otbTensorflowCommon.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowCopyUtils.cxx b/include/otbTensorflowCopyUtils.cxx index 63f4a98cff1df0226fca87e43f7ba86fd6a1edf2..32358ace9af7f5a3004984fd95b4839975bc6a76 100644 --- a/include/otbTensorflowCopyUtils.cxx +++ b/include/otbTensorflowCopyUtils.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -407,7 +407,7 @@ ValueToTensor(std::string value) } idx++; } - itkDebugMacro("Returning tensor: "<< out.DebugString()); + otbLogMacro(Debug, << "Returning tensor: "<< out.DebugString()); return out; } diff --git a/include/otbTensorflowCopyUtils.h b/include/otbTensorflowCopyUtils.h index 91cddef555f2eadb89862e65dc82c989fc2444cc..1bf47dd42b2afa8e55d5323bbc0c3dcb43378f58 100644 --- a/include/otbTensorflowCopyUtils.h +++ b/include/otbTensorflowCopyUtils.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -15,6 +15,9 @@ // ITK exception #include "itkMacro.h" +// OTB log +#include "otbMacro.h" + // ITK image iterators #include "itkImageRegionIterator.h" #include "itkImageRegionConstIterator.h" diff --git a/include/otbTensorflowDataTypeBridge.cxx b/include/otbTensorflowDataTypeBridge.cxx index 0a9eded8239e39021609adea8e9291f3bb22f554..a510cb4ea5ecab0c1505e690d79922b2299ddc0d 100644 --- a/include/otbTensorflowDataTypeBridge.cxx +++ b/include/otbTensorflowDataTypeBridge.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowDataTypeBridge.h b/include/otbTensorflowDataTypeBridge.h index bce9791fbe18744e1367c7cdf7d421a980981441..af6be18d335761b7261e6b8c7288cf9b07122bc8 100644 --- a/include/otbTensorflowDataTypeBridge.h +++ b/include/otbTensorflowDataTypeBridge.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowGraphOperations.cxx b/include/otbTensorflowGraphOperations.cxx index 3a4a4402aff61f18e622e7dc659f194fc12a1b1a..d40c4da6a2f49a86cb28069094b4ab9f0cc5b231 100644 --- a/include/otbTensorflowGraphOperations.cxx +++ b/include/otbTensorflowGraphOperations.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,46 +11,53 @@ =========================================================================*/ #include "otbTensorflowGraphOperations.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // // Load SavedModel variables // -void RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) { tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); checkpointPathTensor.scalar<tensorflow::tstring>()() = path; - std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = - {{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; - auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().restore_op_name()}, nullptr); + std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = { + { bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor } + }; + auto status = bundle.session->Run(feed_dict, {}, { bundle.meta_graph_def.saver_def().restore_op_name() }, nullptr); if (!status.ok()) { - itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() ); + itkGenericExceptionMacro("Can't restore the input model: " << status.ToString()); } } // // Save SavedModel variables // -void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) { tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); checkpointPathTensor.scalar<tensorflow::tstring>()() = path; - std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = - {{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; - auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().save_tensor_name()}, nullptr); + std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = { + { bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor } + }; + auto status = bundle.session->Run(feed_dict, {}, { bundle.meta_graph_def.saver_def().save_tensor_name() }, nullptr); if (!status.ok()) { - itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() ); + itkGenericExceptionMacro("Can't restore the input model: " << status.ToString()); } } // // Load a SavedModel // -void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::vector<std::string> tagList) +void +LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::vector<std::string> tagList) { // If the tag list is empty, we push back the default tag for model serving if (tagList.size() == 0) @@ -63,92 +70,117 @@ void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bu // Call to tensorflow::LoadSavedModel tensorflow::RunOptions runoptions; runoptions.set_trace_level(tensorflow::RunOptions_TraceLevel_FULL_TRACE); - auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions, - path, tagSets, &bundle); + auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions, path, tagSets, &bundle); if (!status.ok()) { - itkGenericExceptionMacro("Can't load the input model: " << status.ToString() ); + itkGenericExceptionMacro("Can't load the input model: " << status.ToString()); } } + // Get the following attributes of the specified tensors (by name) of a graph: +// - layer name, as specified in the model // - shape // - datatype -void GetTensorAttributes(const tensorflow::protobuf::Map<std::string, tensorflow::TensorInfo> layers, std::vector<std::string> & tensorsNames, - std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes) +void +GetTensorAttributes(const tensorflow::protobuf::Map<std::string, tensorflow::TensorInfo> layers, + std::vector<std::string> & tensorsNames, + std::vector<std::string> & layerNames, + std::vector<tensorflow::TensorShapeProto> & shapes, + std::vector<tensorflow::DataType> & dataTypes) { // Allocation shapes.clear(); - shapes.reserve(tensorsNames.size()); dataTypes.clear(); - dataTypes.reserve(tensorsNames.size()); + layerNames.clear(); - itkDebugMacro("Nodes contained in the model: "); - int i = 0; + // Debug infos + otbLogMacro(Debug, << "Nodes contained in the model: "); for (auto const & layer : layers) - { - itkDebugMacro("Node "<< i << " inside the model: " << layer.first); - i+=1; - } + otbLogMacro(Debug, << "\t" << layer.first); - // Get infos - for (std::vector<std::string>::iterator nameIt = tensorsNames.begin(); - nameIt != tensorsNames.end(); ++nameIt) + // When the user doesn't specify output.names, m_OutputTensors defaults to an empty list that we can not iterate over. + // We change it to a list containing an empty string [""] + if (tensorsNames.size() == 0) { - bool found = false; - itkDebugMacro("Searching for corresponding node of: " << (*nameIt) << "... "); - for (auto const & layer : layers) - { - // layer is a pair (name, tensor_info) - // cf https://stackoverflow.com/questions/63181951/how-to-get-graph-or-graphdef-from-a-given-model - std::string layername = layer.first; - if (layername.substr(0, layername.find(":")).compare((*nameIt)) == 0) - { - found = true; - const tensorflow::TensorInfo& tensor_info = layer.second; + otbLogMacro(Debug, << "No output.name specified. Using a default list with one empty string."); + tensorsNames.push_back(""); + } - itkDebugMacro("Found: " << layername << " in the model"); + // Next, we fill layerNames + int k = 0; // counter used for tensorsNames + for (auto const & name: tensorsNames) + { + bool found = false; + tensorflow::TensorInfo tensor_info; - // Set default to DT_FLOAT - tensorflow::DataType ts_dt = tensorflow::DT_FLOAT; + // If the user didn't specify the placeholdername, choose the kth layer inside the model + if (name.size() == 0) + { + found = true; + // select the k-th element of `layers` + auto it = layers.begin(); + std::advance(it, k); + layerNames.push_back(it->second.name()); + tensor_info = it->second; + otbLogMacro(Debug, << "Input " << k << " corresponds to " << it->first << " in the model"); + } - // Default (input?) tensor type - ts_dt = tensor_info.dtype(); - dataTypes.push_back(ts_dt); + // Else, if the user specified the placeholdername, find the corresponding layer inside the model + else + { + otbLogMacro(Debug, << "Searching for corresponding node of: " << name << "... "); + for (auto const & layer : layers) + { + // layer is a pair (name, tensor_info) + // cf https://stackoverflow.com/questions/63181951/how-to-get-graph-or-graphdef-from-a-given-model + std::string layername = layer.first; + if (layername.substr(0, layername.find(":")).compare(name) == 0) + { + found = true; + layerNames.push_back(layer.second.name()); + tensor_info = layer.second; + otbLogMacro(Debug, << "Found: " << layer.second.name() << " in the model"); + } + } // next layer + } // end else - // Get the tensor's shape - // Here we assure it's a tensor, with 1 shape - tensorflow::TensorShapeProto ts_shp = tensor_info.tensor_shape(); - shapes.push_back(ts_shp); - } - } // next layer + k += 1; if (!found) { - itkGenericExceptionMacro("Tensor name \"" << (*nameIt) << "\" not found. \n" << - "You can list all inputs/outputs of your SavedModel by " << - "running: \n\t `saved_model_cli show --dir your_model_dir --all`"); - + itkGenericExceptionMacro("Tensor name \"" << name << "\" not found. \n" + << "You can list all inputs/outputs of your SavedModel by " + << "running: \n\t `saved_model_cli show --dir your_model_dir --all`"); } + + // Default tensor type + tensorflow::DataType ts_dt = tensor_info.dtype(); + dataTypes.push_back(ts_dt); + + // Get the tensor's shape + // Here we assure it's a tensor, with 1 shape + tensorflow::TensorShapeProto ts_shp = tensor_info.tensor_shape(); + shapes.push_back(ts_shp); } // next tensor name } // // Print a lot of stuff about the specified nodes of the graph // -void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & nodesNames) +void +PrintNodeAttributes(const tensorflow::GraphDef & graph, const std::vector<std::string> & nodesNames) { std::cout << "Go through graph:" << std::endl; std::cout << "#\tname" << std::endl; - for (int i = 0 ; i < graph.node_size() ; i++) + for (int i = 0; i < graph.node_size(); i++) { tensorflow::NodeDef node = graph.node(i); std::cout << i << "\t" << node.name() << std::endl; - for (std::vector<std::string>::iterator nameIt = nodesNames.begin(); - nameIt != nodesNames.end(); ++nameIt) + for (auto const & name: nodesNames) { - if (node.name().compare((*nameIt)) == 0) + if (node.name().compare(name) == 0) { std::cout << "Node " << i << " : " << std::endl; std::cout << "\tName: " << node.name() << std::endl; @@ -160,7 +192,7 @@ void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::st // display all attributes of the node std::cout << "\tAttributes of the node: " << std::endl; - for (auto attr = node.attr().begin() ; attr != node.attr().end() ; attr++) + for (auto attr = node.attr().begin(); attr != node.attr().end(); attr++) { std::cout << "\t\tKey: " << attr->first << std::endl; std::cout << "\t\tValue.value_case(): " << attr->second.value_case() << std::endl; @@ -170,9 +202,9 @@ void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::st std::cout << "\t\t-------------------------------------------------" << std::endl; std::cout << std::endl; } // next attribute - } // node name match - } // next node name - } // next node of the graph + } // node name match + } // next node name + } // next node of the graph } } // end namespace tf diff --git a/include/otbTensorflowGraphOperations.h b/include/otbTensorflowGraphOperations.h index 61c92411cc1a5ad9103d4fa25af662190de09f1c..6ad4a4e29880e10a5f10cbc4f1e945db9ca3c6a6 100644 --- a/include/otbTensorflowGraphOperations.h +++ b/include/otbTensorflowGraphOperations.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -24,6 +24,9 @@ // ITK exception #include "itkMacro.h" +// OTB log +#include "otbMacro.h" + namespace otb { namespace tf { @@ -44,7 +47,7 @@ void GetTensorAttributes(const tensorflow::protobuf::Map<std::string, tensorflow std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes); // Print a lot of stuff about the specified nodes of the graph -void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & nodesNames); +void PrintNodeAttributes(const tensorflow::GraphDef & graph, const std::vector<std::string> & nodesNames); } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h index 693a150d68436d851c1f58b475a283d4350ae7ac..d10648ea00afc6fc624fc9ffac8e19bbdffdf4f2 100644 --- a/include/otbTensorflowMultisourceModelBase.h +++ b/include/otbTensorflowMultisourceModelBase.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -131,7 +131,7 @@ public: itkGetMacro(OutputExpressionFields, SizeListType); /** User placeholders */ - void SetUserPlaceholders(DictType dict) {m_UserPlaceholders = dict;} + void SetUserPlaceholders(const DictType & dict) {m_UserPlaceholders = dict;} DictType GetUserPlaceholders() {return m_UserPlaceholders;} /** Target nodes names */ @@ -175,8 +175,9 @@ private: TensorShapeProtoList m_InputTensorsShapes; // Input tensors shapes TensorShapeProtoList m_OutputTensorsShapes; // Output tensors shapes - // Tensor names mapping - std::map<std::string, std::string> m_UserNameToLayerNameMapping; + // Layer names inside the model corresponding to inputs and outputs + StringList m_InputLayers; // List of input names, as contained in the model + StringList m_OutputLayers; // List of output names, as contained in the model }; // end class diff --git a/include/otbTensorflowMultisourceModelBase.hxx b/include/otbTensorflowMultisourceModelBase.hxx index fd3cefc7f08c3830e81493fae56237c9e70cf057..752b7c9d61a861d260dc4dfac89efb66e772a42b 100644 --- a/include/otbTensorflowMultisourceModelBase.hxx +++ b/include/otbTensorflowMultisourceModelBase.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -23,6 +23,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> { Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max() ); Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max() ); + + m_SavedModel = NULL; } template <class TInputImage, class TOutputImage> @@ -32,15 +34,16 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> { auto signatures = this->GetSavedModel()->GetSignatures(); tensorflow::SignatureDef signature_def; - // If serving_default key exists (which is the default for TF saved model), choose it as signature - // Else, choose the first one + if (signatures.size() == 0) { itkExceptionMacro("There are no available signatures for this tag-set. \n" << - "Please check which tag-set to use by running "<< + "Please check which tag-set to use by running "<< "`saved_model_cli show --dir your_model_dir --all`"); } + // If serving_default key exists (which is the default for TF saved model), choose it as signature + // Else, choose the first one if (signatures.contains(tensorflow::kDefaultServingSignatureDefKey)) { signature_def = signatures.at(tensorflow::kDefaultServingSignatureDefKey); @@ -103,33 +106,24 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> { // Add the user's placeholders - for (auto& dict: this->GetUserPlaceholders()) - { - inputs.push_back(dict); - } + std::copy(this->GetUserPlaceholders().begin(), this->GetUserPlaceholders().end(), std::back_inserter(inputs)); // Run the TF session here // The session will initialize the outputs - // Inputs corresponds to the names of placeholder, as specified when calling TensorFlowModelServe application - // Decloud example: For TF1 model, it is specified by the user as "tower_0:s2_t". For TF2 model, it must be specified by the user as "s2_t" - // Thus, for TF2, we must transform that to "serving_default_s2_t" + // `inputs` corresponds to a mapping {name, tensor}, with the name being specified by the user when calling TensorFlowModelServe + // we must adapt it to `inputs_new`, that corresponds to a mapping {layerName, tensor}, with the layerName being from the model DictType inputs_new; + int k = 0; for (auto& dict: inputs) { - DictElementType element = {m_UserNameToLayerNameMapping[dict.first], dict.second}; + DictElementType element = {m_InputLayers[k], dict.second}; inputs_new.push_back(element); - } - - StringList m_OutputTensors_new; - for (auto& name: m_OutputTensors) - { - m_OutputTensors_new.push_back(m_UserNameToLayerNameMapping[name]); + k+=1; } // Run the session, evaluating our output tensors from the graph - auto status = this->GetSavedModel()->session.get()->Run(inputs_new, m_OutputTensors_new, m_TargetNodesNames, &outputs); - + auto status = this->GetSavedModel()->session.get()->Run(inputs_new, m_OutputLayers, m_TargetNodesNames, &outputs); if (!status.ok()) { @@ -162,36 +156,20 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> " and the number of input tensors names is " << m_InputPlaceholders.size()); } - // Check that the number of the following is the same - // - output tensors names - // - output expression fields - if (m_OutputExpressionFields.size() != m_OutputTensors.size()) - { - itkExceptionMacro("Number of output tensors names is " << m_OutputTensors.size() << - " but the number of output fields of expression is " << m_OutputExpressionFields.size()); - } - ////////////////////////////////////////////////////////////////////////////////////////// // Get tensors information ////////////////////////////////////////////////////////////////////////////////////////// // Set all subelement of the model auto signaturedef = this->GetSignatureDef(); - for (auto& output: signaturedef.outputs()) - { - std::string userName = output.first.substr(0, output.first.find(":")); - std::string layerName = output.second.name(); - m_UserNameToLayerNameMapping[userName] = layerName; - } - for (auto& input: signaturedef.inputs()) - { - std::string userName = input.first.substr(0, input.first.find(":")); - std::string layerName = input.second.name(); - m_UserNameToLayerNameMapping[userName] = layerName; - } - // Get input and output tensors datatypes and shapes - tf::GetTensorAttributes(signaturedef.inputs(), m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes); - tf::GetTensorAttributes(signaturedef.outputs(), m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes); + // Given the inputs/outputs names that the user specified, get the names of the inputs/outputs contained in the model + // and other infos (shapes, dtypes) + // For example, for output names specified by the user m_OutputTensors = ['s2t', 's2t_pad'], + // this will return m_OutputLayers = ['PartitionedCall:0', 'PartitionedCall:1'] + // In case the user hasn't named the output, e.g. m_OutputTensors = [''], + // this will return the first output m_OutputLayers = ['PartitionedCall:0'] + tf::GetTensorAttributes(signaturedef.inputs(), m_InputPlaceholders, m_InputLayers, m_InputTensorsShapes, m_InputTensorsDataTypes); + tf::GetTensorAttributes(signaturedef.outputs(), m_OutputTensors, m_OutputLayers, m_OutputTensorsShapes, m_OutputTensorsDataTypes); } diff --git a/include/otbTensorflowMultisourceModelFilter.h b/include/otbTensorflowMultisourceModelFilter.h index 74200eb694ed75d45800f1eb87adcd9e4c9a0f95..0e042941ffbace8a15d1446f0a28c88f21df135f 100644 --- a/include/otbTensorflowMultisourceModelFilter.h +++ b/include/otbTensorflowMultisourceModelFilter.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -26,6 +26,9 @@ #include "itkMetaDataObject.h" #include "otbMetaDataKey.h" +// OTB log +#include "otbMacro.h" + namespace otb { diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx index 22be795bc93231f0fa34761d038bc07669a2a2ca..30c458611bda8182b7ddd188cc5b102a18b4a251 100644 --- a/include/otbTensorflowMultisourceModelFilter.hxx +++ b/include/otbTensorflowMultisourceModelFilter.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -333,7 +333,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> if (!OutputRegionToInputRegion(requestedRegion, inRegion, inputImage) ) { // Image does not overlap requested region: set requested region to null - itkDebugMacro( << "Image #" << i << " :\n" << inRegion << " is outside the requested region"); + otbLogMacro(Debug, << "Image #" << i << " :\n" << inRegion << " is outside the requested region"); inRegion.GetModifiableIndex().Fill(0); inRegion.GetModifiableSize().Fill(0); } diff --git a/include/otbTensorflowMultisourceModelLearningBase.h b/include/otbTensorflowMultisourceModelLearningBase.h index 6e4c571dc305f0cb70a24990f3f64c0adbd25eca..0663f17a3f6367d5f5fe0ebbc76b1ca71d64957d 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.h +++ b/include/otbTensorflowMultisourceModelLearningBase.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -100,7 +100,7 @@ protected: TensorflowMultisourceModelLearningBase(); virtual ~TensorflowMultisourceModelLearningBase() {}; - virtual void GenerateOutputInformation(void); + virtual void GenerateOutputInformation(void) override; virtual void GenerateInputRequestedRegion(); diff --git a/include/otbTensorflowMultisourceModelLearningBase.hxx b/include/otbTensorflowMultisourceModelLearningBase.hxx index 49ba00df91aa28a0ef597e6f4ef1b1b31ab6a7b3..28b2328b8b82c49896ed40a1edd18fba5cebd7a7 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.hxx +++ b/include/otbTensorflowMultisourceModelLearningBase.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelTrain.h b/include/otbTensorflowMultisourceModelTrain.h index 58d02d25306c0729607dc802d9c21c0d18b9ffbe..8ec4c38c369d532a706746c9674197ad766f657b 100644 --- a/include/otbTensorflowMultisourceModelTrain.h +++ b/include/otbTensorflowMultisourceModelTrain.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx index 23bb4e57cc58af0b8c4ffa1a5652bddf2cfa3b75..272dd6390668bd233c5ec41b99ff2b088ef313c3 100644 --- a/include/otbTensorflowMultisourceModelTrain.hxx +++ b/include/otbTensorflowMultisourceModelTrain.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelValidate.h b/include/otbTensorflowMultisourceModelValidate.h index ce9fd45c32779cb9d5e056bfb98780854bd953b2..322f6a24e288db9d9acf202e72ffe04ff8d8a8d4 100644 --- a/include/otbTensorflowMultisourceModelValidate.h +++ b/include/otbTensorflowMultisourceModelValidate.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelValidate.hxx b/include/otbTensorflowMultisourceModelValidate.hxx index d264ebb18eb463f27ea24400fe20ab7aa24280a1..8ec685ba81c1ae51111a077e8170dd227be7241e 100644 --- a/include/otbTensorflowMultisourceModelValidate.hxx +++ b/include/otbTensorflowMultisourceModelValidate.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowSampler.h b/include/otbTensorflowSampler.h index ffd6af614475eb0d55447ebd7ccc59e9311107ff..bd363bc8ee191ce7506ca012b68171f6d8bdc828 100644 --- a/include/otbTensorflowSampler.h +++ b/include/otbTensorflowSampler.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowSampler.hxx b/include/otbTensorflowSampler.hxx index cdcbf1fe0097c355b4e2eec79ccabd8b1bfc0791..8c0ea7459ad5e1ac0060438e3c6a73b760fc535a 100644 --- a/include/otbTensorflowSampler.hxx +++ b/include/otbTensorflowSampler.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowSamplingUtils.cxx b/include/otbTensorflowSamplingUtils.cxx index 5a68cdb2e0ac78568a2a54ef7b7e747c60a9b1f4..5cf88f6b171b61c9576b4ca68d855a0d6059d42f 100644 --- a/include/otbTensorflowSamplingUtils.cxx +++ b/include/otbTensorflowSamplingUtils.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowSamplingUtils.h b/include/otbTensorflowSamplingUtils.h index ab808a257449ad9f810bd5051e2c54a5356b02d6..585f90132ea71509fe2a08ff8f56aa1eac2abb3f 100644 --- a/include/otbTensorflowSamplingUtils.h +++ b/include/otbTensorflowSamplingUtils.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -27,23 +27,17 @@ public: typedef typename TImage::PixelType ValueType; typedef vnl_vector<float> CountsType; - Distribution(unsigned int nClasses){ - m_NbOfClasses = nClasses; - m_Dist = CountsType(nClasses, 0); - + explicit Distribution(unsigned int nClasses): m_NbOfClasses(nClasses), m_Dist(CountsType(nClasses, 0)) + { } - Distribution(unsigned int nClasses, float fillValue){ - m_NbOfClasses = nClasses; - m_Dist = CountsType(nClasses, fillValue); - + Distribution(unsigned int nClasses, float fillValue): m_NbOfClasses(nClasses), m_Dist(CountsType(nClasses, fillValue)) + { } - Distribution(){ - m_NbOfClasses = 2; - m_Dist = CountsType(m_NbOfClasses, 0); + Distribution(): m_NbOfClasses(2), m_Dist(CountsType(m_NbOfClasses, 0)) + { } - Distribution(const Distribution & other){ - m_Dist = other.Get(); - m_NbOfClasses = m_Dist.size(); + Distribution(const Distribution & other): m_Dist(other.Get()), m_NbOfClasses(m_Dist.size()) + { } ~Distribution(){} diff --git a/include/otbTensorflowSource.h b/include/otbTensorflowSource.h index ddec24b50d40b3db32588e5cbbbef44d7257f7d0..1556997f9a20c02c1f5f9fd80f92c0fc38270657 100644 --- a/include/otbTensorflowSource.h +++ b/include/otbTensorflowSource.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -25,9 +25,9 @@ namespace otb { /* - * This is a simple helper to create images concatenation. + * This is a helper for images concatenation. * Images must have the same size. - * This is basically the common input type used in every OTB-TF applications. + * This is the common input type used in every OTB-TF applications. */ template<class TImage> class TensorflowSource @@ -60,7 +60,7 @@ public: // Get the source output FloatVectorImagePointerType Get(); - TensorflowSource(){}; + TensorflowSource(); virtual ~TensorflowSource (){}; private: diff --git a/include/otbTensorflowSource.hxx b/include/otbTensorflowSource.hxx index deb8abd2e56ec2d520656c43caddf02dee6f90a5..2ad575866fda4ba7baa9dc494c44bcc38b79f0cd 100644 --- a/include/otbTensorflowSource.hxx +++ b/include/otbTensorflowSource.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -17,12 +17,21 @@ namespace otb { +// +// Constructor +// +template <class TImage> +TensorflowSource<TImage> +::TensorflowSource() +{} + // // Prepare the big stack of images // template <class TImage> void -TensorflowSource<TImage>::Set(FloatVectorImageListType * inputList) +TensorflowSource<TImage> +::Set(FloatVectorImageListType * inputList) { // Create one stack for input images list m_Concatener = ListConcatenerFilterType::New(); diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h index 6aba8656a1119f10968876304755f420f3768be7..4730d3691cd3bd64954091c3bfdbb5bb7422d870 100644 --- a/include/otbTensorflowStreamerFilter.h +++ b/include/otbTensorflowStreamerFilter.h @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx index b323047b0d90b791f4bdf7b23f66af25d7de92c3..59904a54f3df99048dfa383e22dad0ee7bef9784 100644 --- a/include/otbTensorflowStreamerFilter.hxx +++ b/include/otbTensorflowStreamerFilter.hxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py index 1892757e144fb6854df4356d9bd74af96f71374c..117203bafd89bcbfaa272952323434dac4046a8b 100755 --- a/python/ckpt2savedmodel.py +++ b/python/ckpt2savedmodel.py @@ -28,6 +28,7 @@ keras in Tensorflow 2). import argparse from tricks import ckpt_to_savedmodel + def main(): """ Main function @@ -48,5 +49,6 @@ def main(): savedmodel_path=params.model, clear_devices=params.clear_devices) + if __name__ == "__main__": main() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f7716a138fc79aa34c9734e8b5d3a44cfce35b39..eaf6dbd718bbaf1cfe4eee1872cd8d6880f58052 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -49,6 +49,7 @@ set(MODEL4_FC_OUT apTvClTensorflowModelServeFCNN64x64to32x32.tif) set(MODEL1_SAVED model1_updated) #----------- Model training : 1-branch CNN (16x16) Patch-Based ---------------- +set(ENV{OTB_LOGGER_LEVEL} DEBUG) otb_test_application(NAME TensorflowModelTrainCNN16x16PB APP TensorflowModelTrain OPTIONS @@ -68,6 +69,7 @@ otb_test_application(NAME TensorflowModelTrainCNN16x16PB -training.targetnodes "optimizer" -validation.mode "class" ) +set_tests_properties(TensorflowModelTrainCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;$ENV{OTB_LOGGER_LEVEL}") #----------- Model serving : 1-branch CNN (16x16) Patch-Based ---------------- otb_test_application(NAME TensorflowModelServeCNN16x16PB @@ -79,6 +81,7 @@ otb_test_application(NAME TensorflowModelServeCNN16x16PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL1_PB_OUT} ${TEMP}/${MODEL1_PB_OUT}) +set_tests_properties(TensorflowModelServeCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") #----------- Model serving : 2-branch CNN (8x8, 32x32) Patch-Based ---------------- otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32PB @@ -92,7 +95,8 @@ otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL2_PB_OUT} ${TEMP}/${MODEL2_PB_OUT}) -set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32PB PROPERTIES ENVIRONMENT "OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") +set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") + #----------- Model serving : 2-branch CNN (8x8, 32x32) Fully-Conv ---------------- set(ENV{OTB_TF_NSOURCES} 2) @@ -107,7 +111,7 @@ otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32FC VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL2_FC_OUT} ${TEMP}/${MODEL2_FC_OUT}) -set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32FC PROPERTIES ENVIRONMENT "OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") +set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32FC PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") #----------- Model serving : 1-branch FCNN (16x16) Patch-Based ---------------- set(ENV{OTB_TF_NSOURCES} 1) @@ -120,6 +124,8 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN16x16PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL3_PB_OUT} ${TEMP}/${MODEL3_PB_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") + #----------- Model serving : 1-branch FCNN (16x16) Fully-conv ---------------- set(ENV{OTB_TF_NSOURCES} 1) @@ -132,10 +138,11 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN16x16FC VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL3_FC_OUT} ${TEMP}/${MODEL3_FC_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN16x16FC PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") #----------- Model serving : 1-branch FCNN (64x64)-->(32x32), Fully-conv ---------------- set(ENV{OTB_TF_NSOURCES} 1) -otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32.tif +otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32 APP TensorflowModelServe OPTIONS -source1.il ${IMAGEPXS2} -source1.rfieldx 64 -source1.rfieldy 64 -source1.placeholder x @@ -145,5 +152,6 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32.tif VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL4_FC_OUT} ${TEMP}/${MODEL4_FC_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN64x64to32x32 PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") diff --git a/test/otbTensorflowCopyUtilsTests.cxx b/test/otbTensorflowCopyUtilsTests.cxx index 249f16e6368c7ac87e82f8c81ebbb185cf064c13..5b9586460aa1e0cc51a115ea1a6031ebd3cf805e 100644 --- a/test/otbTensorflowCopyUtilsTests.cxx +++ b/test/otbTensorflowCopyUtilsTests.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/test/otbTensorflowTests.cxx b/test/otbTensorflowTests.cxx index c49891210bd8a989947ffd655d35850207555221..50e9a91a57b85feecccf44b49b5c3b57f7e69ec3 100644 --- a/test/otbTensorflowTests.cxx +++ b/test/otbTensorflowTests.cxx @@ -1,7 +1,7 @@ /*========================================================================= Copyright (c) 2018-2019 IRSTEA - Copyright (c) 2020-2020 INRAE + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -21,4 +21,3 @@ void RegisterTests() REGISTER_TEST(boolVecValueToTensorTest); } -