diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..0b9a40282f8a247085847be735807513844c0301 --- /dev/null +++ b/.clang-format @@ -0,0 +1,134 @@ +BasedOnStyle: Mozilla +Language: Cpp +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: true +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +# clang 9.0 AllowAllArgumentsOnNextLine: true +# clang 9.0 AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +# clang 9.0 AllowShortLambdasOnASingleLine: All +# clang 9.0 features AllowShortIfStatementsOnASingleLine: Never +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: All +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBraces: Custom +BraceWrapping: + # clang 9.0 feature AfterCaseLabel: false + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true +## This is the big change from historical ITK formatting! +# Historically ITK used a style similar to https://en.wikipedia.org/wiki/Indentation_style#Whitesmiths_style +# with indented braces, and not indented code. This style is very difficult to automatically +# maintain with code beautification tools. Not indenting braces is more common among +# formatting tools. 
+ IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +#clang 6.0 BreakBeforeInheritanceComma: true +BreakInheritanceList: BeforeComma +BreakBeforeTernaryOperators: true +#clang 6.0 BreakConstructorInitializersBeforeComma: true +BreakConstructorInitializers: BeforeComma +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +## The following line allows larger lines in non-documentation code +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 2 +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: true +IndentPPDirectives: AfterHash +IndentWidth: 2 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: false +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +## The following line allows larger lines in non-documentation code +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +ReflowComments: true +# We may want to sort the includes as a separate pass +SortIncludes: false +# We may 
want to revisit this later +SortUsingDeclarations: false +SpaceAfterCStyleCast: false +# SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 2 +UseTab: Never +... diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..271b1daf7274419b0577a9dc37a862c8b3eab5ab --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,92 @@ +image: gitlab-registry.irstea.fr/remi.cresson/otbtf:2.4-cpu-basic-testing + +variables: + OTB_BUILD: /src/otb/build/OTB/build # Local OTB build directory + OTBTF_SRC: /src/otbtf # Local OTBTF source directory + +workflow: + rules: + - if: $CI_MERGE_REQUEST_ID # Execute jobs in merge request context + - if: $CI_COMMIT_BRANCH == 'develop' # Execute jobs when a new commit is pushed to develop branch + +stages: + - Build + - Static Analysis + - Test + +.update_otbtf_src: &update_otbtf_src + - sudo rm -rf $OTBTF_SRC && sudo ln -s $PWD $OTBTF_SRC # Replace local OTBTF source directory + +.compile_otbtf: &compile_otbtf + - cd $OTB_BUILD && sudo make install -j$(nproc --all) # Rebuild OTB with new OTBTF sources + +before_script: + - *update_otbtf_src + +build: + stage: Build + allow_failure: false + script: + - *compile_otbtf + +flake8: + stage: Static Analysis + allow_failure: true + script: + - sudo apt update && sudo apt install flake8 -y + - python -m flake8 --max-line-length=120 $OTBTF_SRC/python + +pylint: + stage: Static Analysis + allow_failure: true + script: + - sudo apt 
update && sudo apt install pylint -y + - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/python + +codespell: + stage: Static Analysis + allow_failure: true + script: + - sudo pip install codespell && codespell + +cppcheck: + stage: Static Analysis + allow_failure: true + script: + - sudo apt update && sudo apt install cppcheck -y + - cd $OTBTF_SRC/ && cppcheck --enable=all --error-exitcode=1 -I include/ --suppress=missingInclude --suppress=unusedFunction . + +ctest: + stage: Test + script: + - *compile_otbtf + - sudo rm -rf $OTB_BUILD/Testing/Temporary/* # Empty testing temporary folder (old files here) + - cd $OTB_BUILD/ && sudo ctest -L OTBTensorflow # Run ctest + after_script: + - cp -r $OTB_BUILD/Testing/Temporary $CI_PROJECT_DIR/testing # Copy artifacts (they must be in $CI_PROJECT_DIR) + artifacts: + paths: + - testing/*.* + expire_in: 1 week + when: on_failure + +sr4rs: + stage: Test + script: + - *compile_otbtf + - pip3 install pytest pytest-cov + - cd $CI_PROJECT_DIR + - wget -O sr4rs_sentinel2_bands4328_france2020_savedmodel.zip + https://nextcloud.inrae.fr/s/boabW9yCjdpLPGX/download/sr4rs_sentinel2_bands4328_france2020_savedmodel.zip + - unzip -o sr4rs_sentinel2_bands4328_france2020_savedmodel.zip + - wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/qMLLyKCDieqmgWz/download + - unzip -o sr4rs_data.zip + - rm -rf sr4rs + - git clone https://github.com/remicres/sr4rs.git + - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs + - python -m pytest --junitxml=$CI_PROJECT_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py + artifacts: + when: on_failure + paths: + - $CI_PROJECT_DIR/report_sr4rs.xml + expire_in: 1 week diff --git a/CMakeLists.txt b/CMakeLists.txt index f0bc92a6ca6ae979dbf029c86285932851460746..0a4d646a95ba9bf3cae8b93687b0a72522060dc3 100644 --- a/CMakeLists.txt 
+++ b/CMakeLists.txt @@ -15,7 +15,7 @@ if(OTB_USE_TENSORFLOW) find_library(TENSORFLOW_FRAMEWORK_LIB NAMES libtensorflow_framework) set(TENSORFLOW_LIBS "${TENSORFLOW_CC_LIB}" "${TENSORFLOW_FRAMEWORK_LIB}") - + set(OTBTensorflow_THIRD_PARTY "this is a hack to skip header_tests") else() message("Tensorflow support disabled") endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000000000000000000000000000000000000..cd39540e7c15ad843a6f52502ea717faf1c0e38a --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,8 @@ +- Remi Cresson +- Nicolas Narcon +- Benjamin Commandre +- Vincent Delbar +- Loic Lozac'h +- Pratyush Das +- Doctor Who +- Jordi Inglada diff --git a/Dockerfile b/Dockerfile index 53e3f777910b643ae8780a14a3d7bbeb64342bf9..c9aa6e9fcb40cba228c9a4d2c0634eb055e38fdf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,7 +47,7 @@ RUN wget -qO /opt/otbtf/bin/bazelisk https://github.com/bazelbuild/bazelisk/rele && ln -s /opt/otbtf/bin/bazelisk /opt/otbtf/bin/bazel ARG BZL_TARGETS="//tensorflow:libtensorflow_cc.so //tensorflow/tools/pip_package:build_pip_package" -# "--config=opt" will enable 'march=native' (otherwise read comments about CPU compatibilty and edit CC_OPT_FLAGS in build-env-tf.sh) +# "--config=opt" will enable 'march=native' (otherwise read comments about CPU compatibility and edit CC_OPT_FLAGS in build-env-tf.sh) ARG BZL_CONFIGS="--config=nogcp --config=noaws --config=nohdfs --config=opt" # "--compilation_mode opt" is already enabled by default (see tf repo .bazelrc and configure.py) ARG BZL_OPTIONS="--verbose_failures --remote_cache=http://localhost:9090" @@ -77,7 +77,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi # Symlink external libs (required for MKL - libiomp5) && for f in $(find -L /opt/otbtf/include/tf -wholename "*/external/*/*.so"); do ln -s $f /opt/otbtf/lib/; done \ # Compress and save TF binaries - && ( ! 
$ZIP_TF_BIN || zip -9 -j --symlinks /opt/otbtf/tf-$TF.zip tensorflow/cc/saved_model/tag_constants.h bazel-bin/tensorflow/libtensorflow_cc.so* /tmp/tensorflow_pkg/tensorflow*.whl ) \ + && ( ! $ZIP_TF_BIN || zip -9 -j --symlinks /opt/otbtf/tf-$TF.zip tensorflow/cc/saved_model/tag_constants.h tensorflow/cc/saved_model/signature_constants.h bazel-bin/tensorflow/libtensorflow_cc.so* /tmp/tensorflow_pkg/tensorflow*.whl ) \ # Cleaning && rm -rf bazel-* /src/tf /root/.cache/ /tmp/* diff --git a/LICENSE b/LICENSE index a1a86c14aa32ced5bb3489db5ce76447e7e18bda..236ca02805e38ec6a5a3112472a62879155b1846 100644 --- a/LICENSE +++ b/LICENSE @@ -188,7 +188,7 @@ identification within third-party archives. Copyright 2018-2019 Rémi Cresson (IRSTEA) - Copyright 2020 Rémi Cresson (INRAE) + Copyright 2020-2021 Rémi Cresson (INRAE) Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 2ce6452f57996552da5abf006d809e7a73d98a19..a18961b3038b680b4a52f86666bc2ef8005a18d9 100644 --- a/README.md +++ b/README.md @@ -42,8 +42,8 @@ For now you have two options: either use the existing **docker image**, or build Use the latest image from dockerhub: ``` -docker pull mdl4eo/otbtf2.4:cpu -docker run -u otbuser -v $(pwd):/home/otbuser mdl4eo/otbtf2.4:cpu otbcli_PatchesExtraction -help +docker pull mdl4eo/otbtf2.5:cpu +docker run -u otbuser -v $(pwd):/home/otbuser mdl4eo/otbtf2.5:cpu otbcli_PatchesExtraction -help ``` Read more in the [docker use documentation](doc/DOCKERUSE.md). diff --git a/RELEASE_NOTES.txt b/RELEASE_NOTES.txt new file mode 100644 index 0000000000000000000000000000000000000000..764033051413016bf2de6441b15a915fe76840a0 --- /dev/null +++ b/RELEASE_NOTES.txt @@ -0,0 +1,99 @@ +Version 3.0.0-beta (20 nov 2021) +---------------------------------------------------------------- +* Use Tensorflow 2 API everywhere. Everything is backward compatible (old models can still be used). 
+* Support models with no-named inputs and outputs. OTBTF now can resolve the names! :) Just in the same order as they are defined in the computational graph. +* Support user placeholders of type vector (int, float or bool) +* More unit tests, spell check, better static analysis of C++ and python code +* Improve the handling of 3-dimensional output tensors, + more explanation in error messages about output tensors dimensions. +* Improve `PatchesSelection` to locate patches centers with corners or pixels centers depending if the patch size is odd or even. + +Version 2.5 (20 oct 2021) +---------------------------------------------------------------- +* Fix a bug in otbtf.py. The `PatchesImagesReader` wasn't working properly when the streaming was disabled. +* Improve the documentation on docker build and docker use (Thanks to Vincent@LaTelescop and Doctor-Who). + +Version 2.4 (11 apr 2021) +---------------------------------------------------------------- +* Fix a bug: The output image origin was sometimes shifted from a fraction of pixel. This issue happened only with multi-inputs models that have inputs of different spacing. +* Improvement: The output image largest possible region is now computed on the maximum possible area within the expression field. Before that, the largest possible region was too much cropped when an expression field > 1 was used. Now output images are a bit larger when a non unitary expression field is used. 
+ +Version 2.3 (30 mar 2021) +---------------------------------------------------------------- +* More supported numeric types for tensors: + * `tensorflow::DT_FLOAT` + * `tensorflow::DT_DOUBLE` + * `tensorflow::DT_UINT64` + * `tensorflow::DT_INT64` + * `tensorflow::DT_UINT32` + * `tensorflow::DT_INT32` + * `tensorflow::DT_UINT16` + * `tensorflow::DT_INT16` + * `tensorflow::DT_UINT8` +* Update instructions to use docker + +Version 2.2 (29 jan 2021) +---------------------------------------------------------------- +* Huge enhancement of the docker image build (from Vincent@LaTeleScop) + +Version 2.1 (17 nov 2020) +---------------------------------------------------------------- +* New OTBTF python classes to train the models: + * `PatchesReaderBase`: base abstract class for patches readers. Users/developers can implement their own from it! + * `PatchesImagesReader`: a class implementing `PatchesReaderBase` to access the patches images, as they are produced by the OTBTF PatchesExtraction application. + * `IteratorBase`: base class to iterate on `PatchesReaderBase`-derived readers. + * `RandomIterator`: an iterator implementing `IteratorBase` designed to randomly access elements. + * `Dataset`: generic class to build datasets, consisting essentially of the assembly of a `PatchesReaderBase`-derived reader, and a `IteratorBase`-derived iterator. The `Dataset` handles the gathering of the data using a thread. It can be used as a `tf.dataset` to feed computational graphs. + * `DatasetFromPatchesImages`: a `Dataset` that uses a `PatchesImagesReader` to allow users/developers to stream their patches generated using the OTBTF PatchesExtraction through a `tf.dataset` which implements a streaming mechanism, enabling low memory footprint and high performance I/O thank to a threaded reading mechanism. 
+* Fix in dockerfile (from Pratyush Das) to work with WSL2 + +Version 2.0 (29 may 2020) +---------------------------------------------------------------- +* Now using TensorFlow 2.0! Some minor migration of python models, because we stick with `tf.compat.v1`. +* Python functions to read patches now use GDAL +* Lighter docker images (thanks to Vincent@LaTeleScop) + +Version 1.8.0 (14 jan 2020) +---------------------------------------------------------------- +* PatchesExtraction supports no-data (a different value for each source can be set) +* New sampling strategy available in PatchesSelection (balanced strategy) + +Version 1.7.0 (15 oct 2019) +---------------------------------------------------------------- +* Add a new application for patches selection (experimental) +* New docker images that are GPU-enabled using NVIDIA runtime + +Version 1.6.0 (18 jul 2019) +---------------------------------------------------------------- +* Fix a bug related to coordinates tolerance (TensorflowModelTrain can now use patches that do not occupy physically the same space) +* Fix dockerfile (add important environment variables, add a non-root user, add an example how to run the docker image) +* Document the provided Gaetano et al. 
two-branch CNN + +Version 1.5.1 (18 jun 2019) +---------------------------------------------------------------- +* Ubuntu bionic dockerfile + instructions +* Doc tags for QGIS3 integration +* Add cmake tests (3 models are tested in various configuration on Pan/XS images) +* PatchesExtraction writes patches images with physical spacing + +Version 1.3.0 (18 nov 2018) +---------------------------------------------------------------- +* Add 3 models that can be directly trained with TensorflowModelTrain (one CNN net, one FCN net, one 2-branch CNN net performing separately on PAN and XS images) +* Fix a bug occurring when using a scale factor <> 1 with a non-unit expression field +* Fix incorrect batch size in learning filters when batch size was not a multiple of number of batches +* Add some documentation + +Version 1.2.0 (29 sep 2018) +---------------------------------------------------------------- +* Fix typos in documentation +* Add a new application for dense polygon classes statistics +* Fix a bug in validation step +* Add streaming option of training/validation +* Document filters classes +* Change applications parameters names and roles +* Add a python application that converts a graph into a savedmodel +* Adjust tiling to expression field +* Update license + +Version 1.0.0 (16 may 2018) +---------------------------------------------------------------- +* First release of OTBTF! diff --git a/app/otbDensePolygonClassStatistics.cxx b/app/otbDensePolygonClassStatistics.cxx index 8ee2283e2edc542d4a389756d23de48c9a00c367..1b9b53f699262a0dc8fcd9f8718332043c917fd6 100644 --- a/app/otbDensePolygonClassStatistics.cxx +++ b/app/otbDensePolygonClassStatistics.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -8,18 +9,24 @@ PURPOSE. 
See the above copyright notices for more information. =========================================================================*/ -#include "otbWrapperApplication.h" +#include "itkFixedArray.h" +#include "itkObjectFactory.h" #include "otbWrapperApplicationFactory.h" +// Application engine +#include "otbStandardFilterWatcher.h" +#include "itkFixedArray.h" + +// Filters #include "otbStatisticsXMLFileWriter.h" #include "otbWrapperElevationParametersHandler.h" - #include "otbVectorDataToLabelImageFilter.h" #include "otbImageToNoDataMaskFilter.h" #include "otbStreamingStatisticsMapFromLabelImageFilter.h" #include "otbVectorDataIntoImageProjectionFilter.h" #include "otbImageToVectorImageCastFilter.h" +// OGR #include "otbOGR.h" namespace otb @@ -27,82 +34,73 @@ namespace otb namespace Wrapper { /** Utility function to negate std::isalnum */ -bool IsNotAlphaNum(char c) - { +bool +IsNotAlphaNum(char c) +{ return !std::isalnum(c); - } +} class DensePolygonClassStatistics : public Application { public: /** Standard class typedefs. 
*/ - typedef DensePolygonClassStatistics Self; + typedef DensePolygonClassStatistics Self; typedef Application Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); - - itkTypeMacro(DensePolygonClassStatistics, otb::Application); + itkTypeMacro(DensePolygonClassStatistics, Application); /** DataObjects typedef */ - typedef UInt32ImageType LabelImageType; - typedef UInt8ImageType MaskImageType; - typedef VectorData<> VectorDataType; + typedef UInt32ImageType LabelImageType; + typedef UInt8ImageType MaskImageType; + typedef VectorData<> VectorDataType; /** ProcessObjects typedef */ - typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, - FloatVectorImageType> VectorDataReprojFilterType; + typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, FloatVectorImageType> VectorDataReprojFilterType; - typedef otb::VectorDataToLabelImageFilter<VectorDataType, LabelImageType> RasterizeFilterType; + typedef otb::VectorDataToLabelImageFilter<VectorDataType, LabelImageType> RasterizeFilterType; typedef otb::VectorImage<MaskImageType::PixelType> InternalMaskImageType; typedef otb::ImageToNoDataMaskFilter<FloatVectorImageType, MaskImageType> NoDataMaskFilterType; typedef otb::ImageToVectorImageCastFilter<MaskImageType, InternalMaskImageType> CastFilterType; - typedef otb::StreamingStatisticsMapFromLabelImageFilter<InternalMaskImageType, - LabelImageType> StatsFilterType; + typedef otb::StreamingStatisticsMapFromLabelImageFilter<InternalMaskImageType, LabelImageType> StatsFilterType; - typedef otb::StatisticsXMLFileWriter<FloatVectorImageType::PixelType> StatWriterType; + typedef otb::StatisticsXMLFileWriter<FloatVectorImageType::PixelType> StatWriterType; - -private: - DensePolygonClassStatistics() - { - - } - - void DoInit() override + void + DoInit() { SetName("DensePolygonClassStatistics"); SetDescription("Computes statistics on a training polygon set."); // 
Documentation SetDocLongDescription("The application processes a dense set of polygons " - "intended for training (they should have a field giving the associated " - "class). The geometries are analyzed against a support image to compute " - "statistics : \n" - " - number of samples per class\n" - " - number of samples per geometry\n"); + "intended for training (they should have a field giving the associated " + "class). The geometries are analyzed against a support image to compute " + "statistics : \n" + " - number of samples per class\n" + " - number of samples per geometry\n"); SetDocLimitations("None"); SetDocAuthors("Remi Cresson"); - SetDocSeeAlso(" "); AddDocTag(Tags::Learning); - AddParameter(ParameterType_InputImage, "in", "Input image"); + AddParameter(ParameterType_InputImage, "in", "Input image"); SetParameterDescription("in", "Support image that will be classified"); - + AddParameter(ParameterType_InputVectorData, "vec", "Input vectors"); - SetParameterDescription("vec","Input geometries to analyze"); - + SetParameterDescription("vec", "Input geometries to analyze"); + AddParameter(ParameterType_OutputFilename, "out", "Output XML statistics file"); - SetParameterDescription("out","Output file to store statistics (XML format)"); + SetParameterDescription("out", "Output file to store statistics (XML format)"); AddParameter(ParameterType_ListView, "field", "Field Name"); - SetParameterDescription("field","Name of the field carrying the class number in the input vectors."); - SetListViewSingleSelectionMode("field",true); + SetParameterDescription("field", "Name of the field carrying the class number in the input vectors."); + SetListViewSingleSelectionMode("field", true); ElevationParametersHandler::AddElevationParameters(this, "elev"); @@ -112,156 +110,154 @@ private: SetDocExampleParameterValue("in", "support_image.tif"); SetDocExampleParameterValue("vec", "variousVectors.shp"); SetDocExampleParameterValue("field", "label"); - 
SetDocExampleParameterValue("out","polygonStat.xml"); + SetDocExampleParameterValue("out", "polygonStat.xml"); + } + + void + DoExecute() + { + + // Retrieve the field name + std::vector<int> selectedCFieldIdx = GetSelectedItems("field"); - SetOfficialDocLink(); + if (selectedCFieldIdx.empty()) + { + otbAppLogFATAL(<< "No field has been selected for data labelling!"); + } + + std::vector<std::string> cFieldNames = GetChoiceNames("field"); + std::string fieldName = cFieldNames[selectedCFieldIdx.front()]; + + otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this, "elev"); + + // Get inputs + FloatVectorImageType::Pointer xs = GetParameterImage("in"); + VectorDataType * shp = GetParameterVectorData("vec"); + + // Reproject vector data + m_VectorDataReprojectionFilter = VectorDataReprojFilterType::New(); + m_VectorDataReprojectionFilter->SetInputVectorData(shp); + m_VectorDataReprojectionFilter->SetInputImage(xs); + m_VectorDataReprojectionFilter->Update(); + + // Internal no-data value + const LabelImageType::ValueType intNoData = itk::NumericTraits<LabelImageType::ValueType>::max(); + + // Rasterize vector data (geometry ID) + m_RasterizeFIDFilter = RasterizeFilterType::New(); + m_RasterizeFIDFilter->AddVectorData(m_VectorDataReprojectionFilter->GetOutput()); + m_RasterizeFIDFilter->SetOutputOrigin(xs->GetOrigin()); + m_RasterizeFIDFilter->SetOutputSpacing(xs->GetSignedSpacing()); + m_RasterizeFIDFilter->SetOutputSize(xs->GetLargestPossibleRegion().GetSize()); + m_RasterizeFIDFilter->SetBurnAttribute("________"); // Trick to get the polygon ID + m_RasterizeFIDFilter->SetGlobalWarningDisplay(false); + m_RasterizeFIDFilter->SetOutputProjectionRef(xs->GetProjectionRef()); + m_RasterizeFIDFilter->SetBackgroundValue(intNoData); + m_RasterizeFIDFilter->SetDefaultBurnValue(0); + + // Rasterize vector data (geometry class) + m_RasterizeClassFilter = RasterizeFilterType::New(); + 
m_RasterizeClassFilter->AddVectorData(m_VectorDataReprojectionFilter->GetOutput()); + m_RasterizeClassFilter->SetOutputOrigin(xs->GetOrigin()); + m_RasterizeClassFilter->SetOutputSpacing(xs->GetSignedSpacing()); + m_RasterizeClassFilter->SetOutputSize(xs->GetLargestPossibleRegion().GetSize()); + m_RasterizeClassFilter->SetBurnAttribute(fieldName); + m_RasterizeClassFilter->SetOutputProjectionRef(xs->GetProjectionRef()); + m_RasterizeClassFilter->SetBackgroundValue(intNoData); + m_RasterizeClassFilter->SetDefaultBurnValue(0); + + // No data mask + m_NoDataFilter = NoDataMaskFilterType::New(); + m_NoDataFilter->SetInput(xs); + m_NoDataCastFilter = CastFilterType::New(); + m_NoDataCastFilter->SetInput(m_NoDataFilter->GetOutput()); + + // Stats (geometry ID) + m_FIDStatsFilter = StatsFilterType::New(); + m_FIDStatsFilter->SetInput(m_NoDataCastFilter->GetOutput()); + m_FIDStatsFilter->SetInputLabelImage(m_RasterizeFIDFilter->GetOutput()); + m_FIDStatsFilter->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram")); + AddProcess(m_FIDStatsFilter->GetStreamer(), "Computing number of samples per vector"); + m_FIDStatsFilter->Update(); + + // Stats (geometry class) + m_ClassStatsFilter = StatsFilterType::New(); + m_ClassStatsFilter->SetInput(m_NoDataCastFilter->GetOutput()); + m_ClassStatsFilter->SetInputLabelImage(m_RasterizeClassFilter->GetOutput()); + m_ClassStatsFilter->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram")); + AddProcess(m_ClassStatsFilter->GetStreamer(), "Computing number of samples per class"); + m_ClassStatsFilter->Update(); + + // Remove the no-data entries + StatsFilterType::LabelPopulationMapType fidMap = m_FIDStatsFilter->GetLabelPopulationMap(); + StatsFilterType::LabelPopulationMapType classMap = m_ClassStatsFilter->GetLabelPopulationMap(); + fidMap.erase(intNoData); + classMap.erase(intNoData); + + m_StatWriter = StatWriterType::New(); + m_StatWriter->SetFileName(this->GetParameterString("out")); + 
m_StatWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerClass", classMap); + m_StatWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerVector", fidMap); + m_StatWriter->Update(); } - void DoUpdateParameters() override + void + DoUpdateParameters() { - if ( HasValue("vec") ) - { - std::string vectorFile = GetParameterString("vec"); - ogr::DataSource::Pointer ogrDS = - ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); - ogr::Layer layer = ogrDS->GetLayer(0); - ogr::Feature feature = layer.ogr().GetNextFeature(); + if (HasValue("vec")) + { + std::string vectorFile = GetParameterString("vec"); + ogr::DataSource::Pointer ogrDS = ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); + ogr::Layer layer = ogrDS->GetLayer(0); + ogr::Feature feature = layer.ogr().GetNextFeature(); ClearChoices("field"); - - for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++) - { + + for (int iField = 0; iField < feature.ogr().GetFieldCount(); iField++) + { std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef(); key = item; - std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum); + std::string::iterator end = std::remove_if(key.begin(), key.end(), IsNotAlphaNum); std::transform(key.begin(), end, key.begin(), tolower); - + OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType(); - - if(fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64) - { - std::string tmpKey="field."+key.substr(0, end - key.begin()); - AddChoice(tmpKey,item); - } + + if (fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64) + { + std::string tmpKey = "field." 
+ key.substr(0, end - key.begin()); + AddChoice(tmpKey, item); } } + } - // Check that the extension of the output parameter is XML (mandatory for - // StatisticsXMLFileWriter) - // Check it here to trigger the error before polygons analysis - - if ( HasValue("out") ) - { - // Store filename extension - // Check that the right extension is given : expected .xml - const std::string extension = itksys::SystemTools::GetFilenameLastExtension(this->GetParameterString("out")); - - if (itksys::SystemTools::LowerCase(extension) != ".xml") - { - otbAppLogFATAL( << extension << " is a wrong extension for parameter \"out\": Expected .xml" ); - } - } - } - - void DoExecute() override - { - - // Filters - VectorDataReprojFilterType::Pointer m_VectorDataReprojectionFilter; - RasterizeFilterType::Pointer m_RasterizeFIDFilter; - RasterizeFilterType::Pointer m_RasterizeClassFilter; - NoDataMaskFilterType::Pointer m_NoDataFilter; - CastFilterType::Pointer m_NoDataCastFilter; - StatsFilterType::Pointer m_FIDStatsFilter; - StatsFilterType::Pointer m_ClassStatsFilter; - - // Retrieve the field name - std::vector<int> selectedCFieldIdx = GetSelectedItems("field"); + // Check that the extension of the output parameter is XML (mandatory for + // StatisticsXMLFileWriter) + // Check it here to trigger the error before polygons analysis - if(selectedCFieldIdx.empty()) + if (HasValue("out")) { - otbAppLogFATAL(<<"No field has been selected for data labelling!"); - } - - std::vector<std::string> cFieldNames = GetChoiceNames("field"); - std::string fieldName = cFieldNames[selectedCFieldIdx.front()]; - - otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev"); - - // Get inputs - FloatVectorImageType::Pointer xs = GetParameterImage("in"); - VectorDataType* shp = GetParameterVectorData("vec"); - - // Reproject vector data - m_VectorDataReprojectionFilter = VectorDataReprojFilterType::New(); - m_VectorDataReprojectionFilter->SetInputVectorData(shp); - 
m_VectorDataReprojectionFilter->SetInputImage(xs); - m_VectorDataReprojectionFilter->Update(); - - // Internal no-data value - const LabelImageType::ValueType intNoData = - itk::NumericTraits<LabelImageType::ValueType>::max(); - - // Rasterize vector data (geometry ID) - m_RasterizeFIDFilter = RasterizeFilterType::New(); - m_RasterizeFIDFilter->AddVectorData(m_VectorDataReprojectionFilter->GetOutput()); - m_RasterizeFIDFilter->SetOutputOrigin(xs->GetOrigin()); - m_RasterizeFIDFilter->SetOutputSpacing(xs->GetSignedSpacing()); - m_RasterizeFIDFilter->SetOutputSize(xs->GetLargestPossibleRegion().GetSize()); - m_RasterizeFIDFilter->SetBurnAttribute("________"); // Trick to get the polygon ID - m_RasterizeFIDFilter->SetGlobalWarningDisplay(false); - m_RasterizeFIDFilter->SetOutputProjectionRef(xs->GetProjectionRef()); - m_RasterizeFIDFilter->SetBackgroundValue(intNoData); - m_RasterizeFIDFilter->SetDefaultBurnValue(0); - - // Rasterize vector data (geometry class) - m_RasterizeClassFilter = RasterizeFilterType::New(); - m_RasterizeClassFilter->AddVectorData(m_VectorDataReprojectionFilter->GetOutput()); - m_RasterizeClassFilter->SetOutputOrigin(xs->GetOrigin()); - m_RasterizeClassFilter->SetOutputSpacing(xs->GetSignedSpacing()); - m_RasterizeClassFilter->SetOutputSize(xs->GetLargestPossibleRegion().GetSize()); - m_RasterizeClassFilter->SetBurnAttribute(fieldName); - m_RasterizeClassFilter->SetOutputProjectionRef(xs->GetProjectionRef()); - m_RasterizeClassFilter->SetBackgroundValue(intNoData); - m_RasterizeClassFilter->SetDefaultBurnValue(0); - - // No data mask - m_NoDataFilter = NoDataMaskFilterType::New(); - m_NoDataFilter->SetInput(xs); - m_NoDataCastFilter = CastFilterType::New(); - m_NoDataCastFilter->SetInput(m_NoDataFilter->GetOutput()); - - // Stats (geometry ID) - m_FIDStatsFilter = StatsFilterType::New(); - m_FIDStatsFilter->SetInput(m_NoDataCastFilter->GetOutput()); - m_FIDStatsFilter->SetInputLabelImage(m_RasterizeFIDFilter->GetOutput()); - 
m_FIDStatsFilter->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram")); - AddProcess(m_FIDStatsFilter->GetStreamer(), "Computing number of samples per vector"); - m_FIDStatsFilter->Update(); - - // Stats (geometry class) - m_ClassStatsFilter = StatsFilterType::New(); - m_ClassStatsFilter->SetInput(m_NoDataCastFilter->GetOutput()); - m_ClassStatsFilter->SetInputLabelImage(m_RasterizeClassFilter->GetOutput()); - m_ClassStatsFilter->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram")); - AddProcess(m_ClassStatsFilter->GetStreamer(), "Computing number of samples per class"); - m_ClassStatsFilter->Update(); - - // Remove the no-data entries - StatsFilterType::LabelPopulationMapType fidMap = m_FIDStatsFilter->GetLabelPopulationMap(); - StatsFilterType::LabelPopulationMapType classMap = m_ClassStatsFilter->GetLabelPopulationMap(); - fidMap.erase(intNoData); - classMap.erase(intNoData); - - StatWriterType::Pointer statWriter = StatWriterType::New(); - statWriter->SetFileName(this->GetParameterString("out")); - statWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerClass",classMap); - statWriter->AddInputMap<StatsFilterType::LabelPopulationMapType>("samplesPerVector",fidMap); - statWriter->Update(); + // Store filename extension + // Check that the right extension is given : expected .xml + const std::string extension = itksys::SystemTools::GetFilenameLastExtension(this->GetParameterString("out")); + if (itksys::SystemTools::LowerCase(extension) != ".xml") + { + otbAppLogFATAL(<< extension << " is a wrong extension for parameter \"out\": Expected .xml"); + } + } } - + +private: + // Filters + VectorDataReprojFilterType::Pointer m_VectorDataReprojectionFilter; + RasterizeFilterType::Pointer m_RasterizeFIDFilter; + RasterizeFilterType::Pointer m_RasterizeClassFilter; + NoDataMaskFilterType::Pointer m_NoDataFilter; + CastFilterType::Pointer m_NoDataCastFilter; + StatsFilterType::Pointer m_FIDStatsFilter; + 
StatsFilterType::Pointer m_ClassStatsFilter; + StatWriterType::Pointer m_StatWriter; }; } // end of namespace Wrapper diff --git a/app/otbImageClassifierFromDeepFeatures.cxx b/app/otbImageClassifierFromDeepFeatures.cxx index f861da418c6ce4a257e673132fe8fda04fb6bc6f..3760f587aa23d13fdae233d4b79945969efc1eca 100644 --- a/app/otbImageClassifierFromDeepFeatures.cxx +++ b/app/otbImageClassifierFromDeepFeatures.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -33,23 +34,23 @@ class ImageClassifierFromDeepFeatures : public CompositeApplication { public: /** Standard class typedefs. */ - typedef ImageClassifierFromDeepFeatures Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef ImageClassifierFromDeepFeatures Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); itkTypeMacro(ImageClassifierFromDeepFeatures, otb::Wrapper::CompositeApplication); private: - // // Add an input source, which includes: // -an input image list // -an input patchsize (dimensions of samples) // - void AddAnInputImage(int inputNumber = 0) + void + AddAnInputImage(int inputNumber = 0) { inputNumber++; @@ -60,11 +61,11 @@ private: // Populate group ShareParameter(ss_key_group.str(), "tfmodel." 
+ ss_key_group.str(), ss_desc_group.str()); - } - void DoInit() + void + DoInit() { SetName("ImageClassifierFromDeepFeatures"); @@ -81,54 +82,48 @@ private: ClearApplications(); // Add applications - AddApplication("ImageClassifier", "classif", "Images classifier" ); - AddApplication("TensorflowModelServe", "tfmodel", "Serve the TF model" ); + AddApplication("ImageClassifier", "classif", "Images classifier"); + AddApplication("TensorflowModelServe", "tfmodel", "Serve the TF model"); // Model shared parameters AddAnInputImage(); - for (int i = 1; i < tf::GetNumberOfSources() ; i++) + for (int i = 1; i < tf::GetNumberOfSources(); i++) { AddAnInputImage(i); } - ShareParameter("deepmodel", "tfmodel.model", - "Deep net model parameters", "Deep net model parameters"); - ShareParameter("output", "tfmodel.output", - "Deep net outputs parameters", - "Deep net outputs parameters"); - ShareParameter("optim", "tfmodel.optim", - "This group of parameters allows optimization of processing time", - "This group of parameters allows optimization of processing time"); + ShareParameter("deepmodel", "tfmodel.model", "Deep net model parameters", "Deep net model parameters"); + ShareParameter("output", "tfmodel.output", "Deep net outputs parameters", "Deep net outputs parameters"); + ShareParameter("optim", + "tfmodel.optim", + "This group of parameters allows optimization of processing time", + "This group of parameters allows optimization of processing time"); // Classify shared parameters - ShareParameter("model" , "classif.model" , "Model file" , "Model file" ); - ShareParameter("imstat" , "classif.imstat" , "Statistics file" , "Statistics file" ); - ShareParameter("nodatalabel", "classif.nodatalabel", "Label mask value" , "Label mask value" ); - ShareParameter("out" , "classif.out" , "Output image" , "Output image" ); - ShareParameter("confmap" , "classif.confmap" , "Confidence map image", "Confidence map image"); - ShareParameter("ram" , "classif.ram" , "Ram" , "Ram" ); - + 
ShareParameter("model", "classif.model", "Model file", "Model file"); + ShareParameter("imstat", "classif.imstat", "Statistics file", "Statistics file"); + ShareParameter("nodatalabel", "classif.nodatalabel", "Label mask value", "Label mask value"); + ShareParameter("out", "classif.out", "Output image", "Output image"); + ShareParameter("confmap", "classif.confmap", "Confidence map image", "Confidence map image"); + ShareParameter("ram", "classif.ram", "Ram", "Ram"); } - - void DoUpdateParameters() + void + DoUpdateParameters() { UpdateInternalParameters("classif"); } - void DoExecute() + void + DoExecute() { ExecuteInternal("tfmodel"); - GetInternalApplication("classif")->SetParameterInputImage("in", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); + GetInternalApplication("classif")->SetParameterInputImage( + "in", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); UpdateInternalParameters("classif"); ExecuteInternal("classif"); - } // DOExecute() - - void AfterExecuteAndWriteOutputs() - { - // Nothing to do } }; } // namespace Wrapper } // namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::ImageClassifierFromDeepFeatures ) +OTB_APPLICATION_EXPORT(otb::Wrapper::ImageClassifierFromDeepFeatures) diff --git a/app/otbLabelImageSampleSelection.cxx b/app/otbLabelImageSampleSelection.cxx index 5364453d84f4c99f28bab2c938a0af220db02b6d..f0d2c03debfaae2ea80c8b3c65e90bbb801a26e0 100644 --- a/app/otbLabelImageSampleSelection.cxx +++ b/app/otbLabelImageSampleSelection.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -34,59 +35,62 @@ class LabelImageSampleSelection : public Application { public: /** Standard class typedefs. 
*/ - typedef LabelImageSampleSelection Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef LabelImageSampleSelection Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); itkTypeMacro(LabelImageSampleSelection, Application); /** Vector data typedefs */ - typedef VectorDataType::DataTreeType DataTreeType; - typedef itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType; - typedef VectorDataType::DataNodeType DataNodeType; - typedef DataNodeType::Pointer DataNodePointer; + typedef VectorDataType::DataTreeType DataTreeType; + typedef itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType; + typedef VectorDataType::DataNodeType DataNodeType; + typedef DataNodeType::Pointer DataNodePointer; /** typedefs */ - typedef Int16ImageType LabelImageType; - typedef unsigned int IndexValueType; + typedef Int16ImageType LabelImageType; + typedef unsigned int IndexValueType; - void DoUpdateParameters() - { - } + void + DoUpdateParameters() + {} /* * Display the percentage */ - void ShowProgress(unsigned int count, unsigned int total, unsigned int step = 1000) + void + ShowProgress(unsigned int count, unsigned int total, unsigned int step = 1000) { if (count % step == 0) { - std::cout << std::setprecision(3) << "\r" << (100.0 * count / (float) total) << "% " << std::flush; + std::cout << std::setprecision(3) << "\r" << (100.0 * count / (float)total) << "% " << std::flush; } } - void ShowProgressDone() + void + ShowProgressDone() { std::cout << "\rDone " << std::flush; std::cout << std::endl; } - void DoInit() + void + DoInit() { // Documentation SetName("LabelImageSampleSelection"); SetDescription("This application extracts points from an input label image. 
" - "This application is like \"SampleSelection\", but uses an input label " - "image, rather than an input vector data."); + "This application is like \"SampleSelection\", but uses an input label " + "image, rather than an input vector data."); SetDocLongDescription("This application produces a vector data containing " - "a set of points centered on the pixels of the input label image. " - "The user can control the number of points. The default strategy consists " - "in producing the same number of points in each class. If one class has a " - "smaller number of points than requested, this one is adjusted."); + "a set of points centered on the pixels of the input label image. " + "The user can control the number of points. The default strategy consists " + "in producing the same number of points in each class. If one class has a " + "smaller number of points than requested, this one is adjusted."); SetDocAuthors("Remi Cresson"); @@ -96,39 +100,41 @@ public: // Strategy AddParameter(ParameterType_Choice, "strategy", "Sampling strategy"); - AddChoice("strategy.constant","Set the same samples counts for all classes"); - SetParameterDescription("strategy.constant","Set the same samples counts for all classes"); + AddChoice("strategy.constant", "Set the same samples counts for all classes"); + SetParameterDescription("strategy.constant", "Set the same samples counts for all classes"); AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes"); SetParameterDescription("strategy.constant.nb", "Number of samples for all classes"); - SetMinimumParameterIntValue("strategy.constant.nb",1); - SetDefaultParameterInt("strategy.constant.nb",1000); + SetMinimumParameterIntValue("strategy.constant.nb", 1); + SetDefaultParameterInt("strategy.constant.nb", 1000); - AddChoice("strategy.total","Set the total number of samples to generate, and use class proportions."); - SetParameterDescription("strategy.total","Set the total number of samples to 
generate, and use class proportions."); - AddParameter(ParameterType_Int,"strategy.total.v","The number of samples to generate"); - SetParameterDescription("strategy.total.v","The number of samples to generate"); - SetMinimumParameterIntValue("strategy.total.v",1); - SetDefaultParameterInt("strategy.total.v",1000); + AddChoice("strategy.total", "Set the total number of samples to generate, and use class proportions."); + SetParameterDescription("strategy.total", + "Set the total number of samples to generate, and use class proportions."); + AddParameter(ParameterType_Int, "strategy.total.v", "The number of samples to generate"); + SetParameterDescription("strategy.total.v", "The number of samples to generate"); + SetMinimumParameterIntValue("strategy.total.v", 1); + SetDefaultParameterInt("strategy.total.v", 1000); - AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled"); - SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled"); + AddChoice("strategy.smallest", "Set same number of samples for all classes, with the smallest class fully sampled"); + SetParameterDescription("strategy.smallest", + "Set same number of samples for all classes, with the smallest class fully sampled"); - AddChoice("strategy.all","Take all samples"); - SetParameterDescription("strategy.all","Take all samples"); + AddChoice("strategy.all", "Take all samples"); + SetParameterDescription("strategy.all", "Take all samples"); // Default strategy : smallest - SetParameterString("strategy","constant"); + SetParameterString("strategy", "constant"); // Input no-data value AddParameter(ParameterType_Int, "nodata", "nodata value"); - MandatoryOn ("nodata"); - SetDefaultParameterInt ("nodata", -1); + MandatoryOn("nodata"); + SetDefaultParameterInt("nodata", -1); // Padding AddParameter(ParameterType_Int, "pad", "padding, in pixels"); - SetDefaultParameterInt ("pad", 0); 
- MandatoryOff ("pad"); + SetDefaultParameterInt("pad", 0); + MandatoryOff("pad"); // Output points AddParameter(ParameterType_OutputVectorData, "outvec", "output set of points"); @@ -138,19 +144,20 @@ public: SetDocExampleParameterValue("outvec", "terrain_truth_points_sel.sqlite"); AddRAMParameter(); - } - void DoExecute() + void + DoExecute() { // Count the number of pixels in each class const LabelImageType::InternalPixelType MAX_NB_OF_CLASSES = - itk::NumericTraits<LabelImageType::InternalPixelType>::max();; + itk::NumericTraits<LabelImageType::InternalPixelType>::max(); + ; LabelImageType::InternalPixelType class_begin = MAX_NB_OF_CLASSES; LabelImageType::InternalPixelType class_end = 0; - vnl_vector<IndexValueType> tmp_number_of_samples(MAX_NB_OF_CLASSES, 0); + vnl_vector<IndexValueType> tmp_number_of_samples(MAX_NB_OF_CLASSES, 0); otbAppLogINFO("Computing number of pixels in each class"); @@ -160,10 +167,10 @@ public: m_StreamingManager->SetAvailableRAMInMB(GetParameterInt("ram")); // We pad the image, if this is requested by the user - LabelImageType::Pointer inputImage = GetParameterInt16Image("inref"); + LabelImageType::Pointer inputImage = GetParameterInt16Image("inref"); LabelImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion(); entireRegion.ShrinkByRadius(GetParameterInt("pad")); - m_StreamingManager->PrepareStreaming(inputImage, entireRegion ); + m_StreamingManager->PrepareStreaming(inputImage, entireRegion); // Get nodata value const LabelImageType::InternalPixelType nodata = GetParameterInt("nodata"); @@ -174,7 +181,7 @@ public: { LabelImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision); tf::PropagateRequestedRegion<LabelImageType>(inputImage, streamRegion); - itk::ImageRegionConstIterator<LabelImageType> inIt (inputImage, streamRegion); + itk::ImageRegionConstIterator<LabelImageType> inIt(inputImage, streamRegion); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { 
LabelImageType::InternalPixelType pixVal = inIt.Get(); @@ -203,14 +210,13 @@ public: // Number of samples in each class (target) vnl_vector<IndexValueType> target_number_of_samples(number_of_classes, 0); - otbAppLogINFO( "Number of classes: " << number_of_classes << - " starting from " << class_begin << - " to " << class_end << " (no-data is " << nodata << ")"); - otbAppLogINFO( "Number of pixels in each class: " << number_of_samples ); + otbAppLogINFO("Number of classes: " << number_of_classes << " starting from " << class_begin << " to " << class_end + << " (no-data is " << nodata << ")"); + otbAppLogINFO("Number of pixels in each class: " << number_of_samples); // Check the smallest number of samples amongst classes IndexValueType min_elem_in_class = itk::NumericTraits<IndexValueType>::max(); - for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) + for (LabelImageType::InternalPixelType classIdx = 0; classIdx < number_of_classes; classIdx++) min_elem_in_class = std::min(min_elem_in_class, number_of_samples[classIdx]); // If one class is empty, throw an error @@ -225,79 +231,73 @@ public: // Compute the sampling step for each classes, depending on the chosen strategy switch (this->GetParameterInt("strategy")) { - // constant - case 0: - { - // Set the target number of samples in each class - target_number_of_samples.fill(GetParameterInt("strategy.constant.nb")); - - // re adjust the number of samples to select in each class - if (min_elem_in_class < target_number_of_samples[0]) - { - otbAppLogWARNING("Smallest class has " << min_elem_in_class << - " samples but a number of " << target_number_of_samples[0] << - " is given. 
Using " << min_elem_in_class); - target_number_of_samples.fill( min_elem_in_class ); - } + // constant + case 0: { + // Set the target number of samples in each class + target_number_of_samples.fill(GetParameterInt("strategy.constant.nb")); - // Compute the sampling step - for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) - step_for_class[classIdx] = number_of_samples[classIdx] / target_number_of_samples[classIdx]; - } - break; + // re adjust the number of samples to select in each class + if (min_elem_in_class < target_number_of_samples[0]) + { + otbAppLogWARNING("Smallest class has " << min_elem_in_class << " samples but a number of " + << target_number_of_samples[0] << " is given. Using " + << min_elem_in_class); + target_number_of_samples.fill(min_elem_in_class); + } - // total - case 1: - { - // Compute the sampling step - IndexValueType step = number_of_samples.sum() / this->GetParameterInt("strategy.total.v"); - if (step == 0) - { - otbAppLogWARNING("The number of samples available is smaller than the required number of samples. " << - "Setting sampling step to 1."); - step = 1; + // Compute the sampling step + for (LabelImageType::InternalPixelType classIdx = 0; classIdx < number_of_classes; classIdx++) + step_for_class[classIdx] = number_of_samples[classIdx] / target_number_of_samples[classIdx]; } - step_for_class.fill(step); - - // Compute the target number of samples - for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) - target_number_of_samples[classIdx] = number_of_samples[classIdx] / step; + break; - } - break; + // total + case 1: { + // Compute the sampling step + IndexValueType step = number_of_samples.sum() / this->GetParameterInt("strategy.total.v"); + if (step == 0) + { + otbAppLogWARNING("The number of samples available is smaller than the required number of samples. 
" + << "Setting sampling step to 1."); + step = 1; + } + step_for_class.fill(step); - // smallest - case 2: - { - // Set the target number of samples to the smallest class - target_number_of_samples.fill( min_elem_in_class ); + // Compute the target number of samples + for (LabelImageType::InternalPixelType classIdx = 0; classIdx < number_of_classes; classIdx++) + target_number_of_samples[classIdx] = number_of_samples[classIdx] / step; + } + break; - // Compute the sampling step - for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) - step_for_class[classIdx] = number_of_samples[classIdx] / target_number_of_samples[classIdx]; + // smallest + case 2: { + // Set the target number of samples to the smallest class + target_number_of_samples.fill(min_elem_in_class); - } - break; + // Compute the sampling step + for (LabelImageType::InternalPixelType classIdx = 0; classIdx < number_of_classes; classIdx++) + step_for_class[classIdx] = number_of_samples[classIdx] / target_number_of_samples[classIdx]; + } + break; - // All - case 3: - { - // Easy - step_for_class.fill(1); - target_number_of_samples = number_of_samples; - } - break; - default: - otbAppLogFATAL("Strategy mode unknown :"<<this->GetParameterString("strategy")); + // All + case 3: { + // Easy + step_for_class.fill(1); + target_number_of_samples = number_of_samples; + } break; + default: + otbAppLogFATAL("Strategy mode unknown :" << this->GetParameterString("strategy")); + break; } // Print quick summary otbAppLogINFO("Sampling summary:"); otbAppLogINFO("\tClass\tStep\tTot"); - for (LabelImageType::InternalPixelType i = 0 ; i < number_of_classes ; i++) + for (LabelImageType::InternalPixelType i = 0; i < number_of_classes; i++) { - vnl_vector<int> tmp (3,0); + vnl_vector<int> tmp(3, 0); tmp[0] = i + class_begin; tmp[1] = step_for_class[i]; tmp[2] = target_number_of_samples[i]; @@ -308,8 +308,8 @@ public: // TODO: how to pre-allocate the datatree? 
m_OutVectorData = VectorDataType::New(); DataTreeType::Pointer tree = m_OutVectorData->GetDataTree(); - DataNodePointer root = tree->GetRoot()->Get(); - DataNodePointer document = DataNodeType::New(); + DataNodePointer root = tree->GetRoot()->Get(); + DataNodePointer document = DataNodeType::New(); document->SetNodeType(DOCUMENT); tree->Add(document, root); @@ -321,15 +321,15 @@ public: // Second iteration, to prepare the samples vnl_vector<IndexValueType> sampledCount(number_of_classes, 0); vnl_vector<IndexValueType> iteratorCount(number_of_classes, 0); - IndexValueType n_tot = 0; - const IndexValueType target_n_tot = target_number_of_samples.sum(); + IndexValueType n_tot = 0; + const IndexValueType target_n_tot = target_number_of_samples.sum(); for (int m_CurrentDivision = 0; m_CurrentDivision < m_NumberOfDivisions; m_CurrentDivision++) { LabelImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision); tf::PropagateRequestedRegion<LabelImageType>(inputImage, streamRegion); - itk::ImageRegionConstIterator<LabelImageType> inIt (inputImage, streamRegion); + itk::ImageRegionConstIterator<LabelImageType> inIt(inputImage, streamRegion); - for (inIt.GoToBegin() ; !inIt.IsAtEnd() ; ++inIt) + for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { LabelImageType::InternalPixelType classVal = inIt.Get(); @@ -341,7 +341,7 @@ public: iteratorCount[classVal]++; // Every Xi samples (Xi is the step for class i) - if (iteratorCount[classVal] % ((int) step_for_class[classVal]) == 0 && + if (iteratorCount[classVal] % ((int)step_for_class[classVal]) == 0 && sampledCount[classVal] < target_number_of_samples[classVal]) { // Add this sample @@ -365,15 +365,14 @@ public: } // sample this one } } // next pixel - } // next streaming region + } // next streaming region ShowProgressDone(); - otbAppLogINFO( "Number of samples in each class: " << sampledCount ); + otbAppLogINFO("Number of samples in each class: " << sampledCount); - otbAppLogINFO( "Writing output vector 
data"); + otbAppLogINFO("Writing output vector data"); SetParameterOutputVectorData("outvec", m_OutVectorData); - } private: @@ -381,7 +380,7 @@ private: }; // end of class -} // end namespace wrapper +} // namespace Wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::LabelImageSampleSelection ) +OTB_APPLICATION_EXPORT(otb::Wrapper::LabelImageSampleSelection) diff --git a/app/otbPatchesExtraction.cxx b/app/otbPatchesExtraction.cxx index df13cbc3353e7cfaafec13126077d4c6c8def571..bcea8b906f85314ed455f26575cf593fa718366c 100644 --- a/app/otbPatchesExtraction.cxx +++ b/app/otbPatchesExtraction.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -32,10 +33,10 @@ class PatchesExtraction : public Application { public: /** Standard class typedefs. 
*/ - typedef PatchesExtraction Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef PatchesExtraction Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); @@ -45,23 +46,21 @@ public: typedef otb::TensorflowSampler<FloatVectorImageType, VectorDataType> SamplerType; /** Typedefs for image concatenation */ - typedef TensorflowSource<FloatVectorImageType> TFSourceType; + typedef TensorflowSource<FloatVectorImageType> TFSourceType; // // Store stuff related to one source // struct SourceBundle { - TFSourceType m_ImageSource; // Image source - FloatVectorImageType::SizeType m_PatchSize; // Patch size + TFSourceType m_ImageSource; // Image source + FloatVectorImageType::SizeType m_PatchSize; // Patch size - unsigned int m_NumberOfElements; // Number of output samples - - std::string m_KeyIn; // Key of input image list - std::string m_KeyOut; // Key of output samples image - std::string m_KeyPszX; // Key for samples sizes X - std::string m_KeyPszY; // Key for samples sizes Y - std::string m_KeyNoData; // Key for no-data value + std::string m_KeyIn; // Key of input image list + std::string m_KeyOut; // Key of output samples image + std::string m_KeyPszX; // Key for samples sizes X + std::string m_KeyPszY; // Key for samples sizes Y + std::string m_KeyNoData; // Key for no-data value FloatVectorImageType::InternalPixelType m_NoDataValue; // No data value }; @@ -73,56 +72,57 @@ public: // -an output image (samples) // -an input patchsize (dimensions of samples) // - void AddAnInputImage() + void + AddAnInputImage() { // Number of source unsigned int inputNumber = m_Bundles.size() + 1; // Create keys and descriptions - std::stringstream ss_group_key, ss_desc_group, ss_key_in, ss_key_out, ss_desc_in, - ss_desc_out, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, 
ss_desc_dims_y, ss_key_nodata, ss_desc_nodata; - ss_group_key << "source" << inputNumber; - ss_desc_group << "Parameters for source " << inputNumber; - ss_key_out << ss_group_key.str() << ".out"; - ss_desc_out << "Output patches for image " << inputNumber; - ss_key_in << ss_group_key.str() << ".il"; - ss_desc_in << "Input image(s) " << inputNumber; - ss_key_dims_x << ss_group_key.str() << ".patchsizex"; - ss_desc_dims_x << "X patch size for image " << inputNumber; - ss_key_dims_y << ss_group_key.str() << ".patchsizey"; - ss_desc_dims_y << "Y patch size for image " << inputNumber; - ss_key_nodata << ss_group_key.str() << ".nodata"; - ss_desc_nodata << "No-data value for image " << inputNumber << "(used only if \"usenodata\" is on)"; + std::stringstream ss_group_key, ss_desc_group, ss_key_in, ss_key_out, ss_desc_in, ss_desc_out, ss_key_dims_x, + ss_desc_dims_x, ss_key_dims_y, ss_desc_dims_y, ss_key_nodata, ss_desc_nodata; + ss_group_key << "source" << inputNumber; + ss_desc_group << "Parameters for source " << inputNumber; + ss_key_out << ss_group_key.str() << ".out"; + ss_desc_out << "Output patches for image " << inputNumber; + ss_key_in << ss_group_key.str() << ".il"; + ss_desc_in << "Input image(s) " << inputNumber; + ss_key_dims_x << ss_group_key.str() << ".patchsizex"; + ss_desc_dims_x << "X patch size for image " << inputNumber; + ss_key_dims_y << ss_group_key.str() << ".patchsizey"; + ss_desc_dims_y << "Y patch size for image " << inputNumber; + ss_key_nodata << ss_group_key.str() << ".nodata"; + ss_desc_nodata << "No-data value for image " << inputNumber << "(used only if \"usenodata\" is on)"; // Populate group - AddParameter(ParameterType_Group, ss_group_key.str(), ss_desc_group.str()); - AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() ); - AddParameter(ParameterType_OutputImage, ss_key_out.str(), ss_desc_out.str()); - AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); - SetMinimumParameterIntValue 
(ss_key_dims_x.str(), 1); - AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); - SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); - AddParameter(ParameterType_Float, ss_key_nodata.str(), ss_desc_nodata.str()); - SetDefaultParameterFloat (ss_key_nodata.str(), 0); + AddParameter(ParameterType_Group, ss_group_key.str(), ss_desc_group.str()); + AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str()); + AddParameter(ParameterType_OutputImage, ss_key_out.str(), ss_desc_out.str()); + AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); + SetMinimumParameterIntValue(ss_key_dims_x.str(), 1); + AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); + SetMinimumParameterIntValue(ss_key_dims_y.str(), 1); + AddParameter(ParameterType_Float, ss_key_nodata.str(), ss_desc_nodata.str()); + SetDefaultParameterFloat(ss_key_nodata.str(), 0); // Add a new bundle SourceBundle bundle; - bundle.m_KeyIn = ss_key_in.str(); - bundle.m_KeyOut = ss_key_out.str(); + bundle.m_KeyIn = ss_key_in.str(); + bundle.m_KeyOut = ss_key_out.str(); bundle.m_KeyPszX = ss_key_dims_x.str(); bundle.m_KeyPszY = ss_key_dims_y.str(); bundle.m_KeyNoData = ss_key_nodata.str(); m_Bundles.push_back(bundle); - } // // Prepare bundles from the number of points // - void PrepareInputs() + void + PrepareInputs() { - for (auto& bundle: m_Bundles) + for (auto & bundle : m_Bundles) { // Create a stack of input images FloatVectorImageListType::Pointer list = GetParameterImageList(bundle.m_KeyIn); @@ -137,30 +137,31 @@ public: } } - void DoUpdateParameters() - { - } - - void DoInit() + void + DoInit() { // Documentation SetName("PatchesExtraction"); SetDescription("This application extracts patches in multiple input images. 
Change " - "the " + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of " - "sources."); - SetDocLongDescription("The application takes an input vector layer which is a set of " - "points, typically the output of the \"SampleSelection\" or the \"LabelImageSampleSelection\" " - "application to sample patches in the input images (samples are centered on the points). " - "A \"source\" parameters group is composed of (i) an input image list (can be " - "one image e.g. high res. image, or multiple e.g. time series), (ii) the size " - "of the patches to sample, and (iii) the output images of patches which will " - "be generated at the end of the process. The example below show how to " - "set the samples sizes. For a SPOT6 image for instance, the patch size can " - "be 64x64 and for an input Sentinel-2 time series the patch size could be " - "1x1. Note that if a dimension size is not defined, the largest one will " - "be used (i.e. input image dimensions. The number of input sources can be changed " - "at runtime by setting the system environment variable " + tf::ENV_VAR_NAME_NSOURCES); + "the " + + tf::ENV_VAR_NAME_NSOURCES + + " environment variable to set the number of " + "sources."); + SetDocLongDescription( + "The application takes an input vector layer which is a set of " + "points, typically the output of the \"SampleSelection\" or the \"LabelImageSampleSelection\" " + "application to sample patches in the input images (samples are centered on the points). " + "A \"source\" parameters group is composed of (i) an input image list (can be " + "one image e.g. high res. image, or multiple e.g. time series), (ii) the size " + "of the patches to sample, and (iii) the output images of patches which will " + "be generated at the end of the process. The example below show how to " + "set the samples sizes. For a SPOT6 image for instance, the patch size can " + "be 64x64 and for an input Sentinel-2 time series the patch size could be " + "1x1. 
Note that if a dimension size is not defined, the largest one will " + "be used (i.e. input image dimensions. The number of input sources can be changed " + "at runtime by setting the system environment variable " + + tf::ENV_VAR_NAME_NSOURCES); SetDocAuthors("Remi Cresson"); @@ -168,36 +169,37 @@ public: // Input/output images AddAnInputImage(); - for (int i = 1; i < tf::GetNumberOfSources() ; i++) + for (int i = 1; i < tf::GetNumberOfSources(); i++) AddAnInputImage(); // Input vector data - AddParameter(ParameterType_InputVectorData, "vec", "Positions of the samples (must be in the same projection as input image)"); + AddParameter( + ParameterType_InputVectorData, "vec", "Positions of the samples (must be in the same projection as input image)"); // No data parameters AddParameter(ParameterType_Bool, "usenodata", "Reject samples that have no-data value"); - MandatoryOff ("usenodata"); + MandatoryOff("usenodata"); // Output label AddParameter(ParameterType_OutputImage, "outlabels", "output labels"); - SetDefaultOutputPixelType ("outlabels", ImagePixelType_uint8); - MandatoryOff ("outlabels"); + SetDefaultOutputPixelType("outlabels", ImagePixelType_uint8); + MandatoryOff("outlabels"); // Class field AddParameter(ParameterType_String, "field", "field of class in the vector data"); // Examples values - SetDocExampleParameterValue("vec", "points.sqlite"); - SetDocExampleParameterValue("source1.il", "$s2_list"); + SetDocExampleParameterValue("vec", "points.sqlite"); + SetDocExampleParameterValue("source1.il", "$s2_list"); SetDocExampleParameterValue("source1.patchsizex", "16"); SetDocExampleParameterValue("source1.patchsizey", "16"); - SetDocExampleParameterValue("field", "class"); - SetDocExampleParameterValue("source1.out", "outpatches_16x16.tif"); - SetDocExampleParameterValue("outlabels", "outlabels.tif"); - + SetDocExampleParameterValue("field", "class"); + SetDocExampleParameterValue("source1.out", "outpatches_16x16.tif"); + 
SetDocExampleParameterValue("outlabels", "outlabels.tif"); } - void DoExecute() + void + DoExecute() { PrepareInputs(); @@ -206,12 +208,12 @@ public: SamplerType::Pointer sampler = SamplerType::New(); sampler->SetInputVectorData(GetParameterVectorData("vec")); sampler->SetField(GetParameterAsString("field")); - if (GetParameterInt("usenodata")==1) - { + if (GetParameterInt("usenodata") == 1) + { otbAppLogINFO("Rejecting samples that have at least one no-data value"); sampler->SetRejectPatchesWithNodata(true); - } - for (auto& bundle: m_Bundles) + } + for (auto & bundle : m_Bundles) { sampler->PushBackInputWithPatchSize(bundle.m_ImageSource.Get(), bundle.m_PatchSize, bundle.m_NoDataValue); } @@ -225,7 +227,7 @@ public: otbAppLogINFO("Number of samples rejected : " << sampler->GetNumberOfRejectedSamples()); // Save patches image - for (unsigned int i = 0 ; i < m_Bundles.size() ; i++) + for (unsigned int i = 0; i < m_Bundles.size(); i++) { SetParameterOutputImage(m_Bundles[i].m_KeyOut, sampler->GetOutputPatchImages()[i]); } @@ -236,14 +238,19 @@ public: { SetParameterOutputImage("outlabels", sampler->GetOutputLabelImage()); } - } + + + void + DoUpdateParameters() + {} + private: std::vector<SourceBundle> m_Bundles; }; // end of class -} // end namespace wrapper +} // namespace Wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::PatchesExtraction ) +OTB_APPLICATION_EXPORT(otb::Wrapper::PatchesExtraction) diff --git a/app/otbPatchesSelection.cxx b/app/otbPatchesSelection.cxx index 170fc9b6e1f2c958abeef47c202d41ea170751d7..186a58287036f1f36502ce3685b4c60515a3d451 100644 --- a/app/otbPatchesSelection.cxx +++ b/app/otbPatchesSelection.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. 
+ Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -31,16 +32,17 @@ #include "itkImageRegionConstIteratorWithOnlyIndex.h" // Functor to retrieve nodata -template<class TPixel, class OutputPixel> +template <class TPixel, class OutputPixel> class IsNoData { public: - IsNoData(){} - ~IsNoData(){} + IsNoData() {} + ~IsNoData() {} - inline OutputPixel operator()( const TPixel & A ) const + inline OutputPixel + operator()(const TPixel & A) const { - for (unsigned int band = 0 ; band < A.Size() ; band++) + for (unsigned int band = 0; band < A.Size(); band++) { if (A[band] != m_NoDataValue) return 1; @@ -48,7 +50,8 @@ public: return 0; } - void SetNoDataValue(typename TPixel::ValueType value) + void + SetNoDataValue(typename TPixel::ValueType value) { m_NoDataValue = value; } @@ -67,52 +70,48 @@ class PatchesSelection : public Application { public: /** Standard class typedefs. */ - typedef PatchesSelection Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef PatchesSelection Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); itkTypeMacro(PatchesSelection, Application); /** Vector data typedefs */ - typedef VectorDataType::DataTreeType DataTreeType; - typedef itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType; - typedef VectorDataType::DataNodeType DataNodeType; - typedef DataNodeType::Pointer DataNodePointer; - typedef DataNodeType::PointType DataNodePointType; + typedef VectorDataType::DataTreeType DataTreeType; + typedef itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType; + typedef VectorDataType::DataNodeType DataNodeType; + typedef DataNodeType::Pointer DataNodePointer; + typedef DataNodeType::PointType DataNodePointType; /** typedefs */ - typedef 
IsNoData<FloatVectorImageType::PixelType, UInt8ImageType::PixelType > IsNoDataFunctorType; + typedef IsNoData<FloatVectorImageType::PixelType, UInt8ImageType::PixelType> IsNoDataFunctorType; typedef itk::UnaryFunctorImageFilter<FloatVectorImageType, UInt8ImageType, IsNoDataFunctorType> IsNoDataFilterType; - typedef itk::FlatStructuringElement<2> StructuringType; - typedef StructuringType::RadiusType RadiusType; + typedef itk::FlatStructuringElement<2> StructuringType; + typedef StructuringType::RadiusType RadiusType; typedef itk::BinaryErodeImageFilter<UInt8ImageType, UInt8ImageType, StructuringType> MorphoFilterType; - typedef otb::StreamingResampleImageFilter<UInt8ImageType,UInt8ImageType> PadFilterType; - typedef itk::NearestNeighborInterpolateImageFunction<UInt8ImageType> NNInterpolatorType; + typedef otb::StreamingResampleImageFilter<UInt8ImageType, UInt8ImageType> PadFilterType; + typedef itk::NearestNeighborInterpolateImageFunction<UInt8ImageType> NNInterpolatorType; typedef tf::Distribution<UInt8ImageType> DistributionType; typedef itk::MaskImageFilter<UInt8ImageType, UInt8ImageType, UInt8ImageType> MaskImageFilterType; - void DoUpdateParameters() - { - } - - - void DoInit() + void + DoInit() { // Documentation SetName("PatchesSelection"); SetDescription("This application generate points sampled at regular interval over " - "the input image region. The grid size and spacing can be configured."); + "the input image region. The grid size and spacing can be configured."); SetDocLongDescription("This application produces a vector data containing " - "a set of points centered on the patches lying in the valid regions of the input image. "); + "a set of points centered on the patches lying in the valid regions of the input image. 
"); SetDocAuthors("Remi Cresson"); @@ -123,180 +122,196 @@ public: // Input no-data value AddParameter(ParameterType_Float, "nodata", "nodata value"); - MandatoryOn ("nodata"); - SetDefaultParameterFloat ("nodata", 0); - AddParameter(ParameterType_Bool, "nocheck", "If on, no check on the validity of patches is performed"); - MandatoryOff ("nocheck"); + MandatoryOn("nodata"); + SetDefaultParameterFloat("nodata", 0); + AddParameter(ParameterType_Bool, "nocheck", "If on, no check on the validity of patches is performed"); + MandatoryOff("nocheck"); // Grid AddParameter(ParameterType_Group, "grid", "grid settings"); AddParameter(ParameterType_Int, "grid.step", "step between patches"); - SetMinimumParameterIntValue ("grid.step", 1); + SetMinimumParameterIntValue("grid.step", 1); AddParameter(ParameterType_Int, "grid.psize", "patches size"); - SetMinimumParameterIntValue ("grid.psize", 1); + SetMinimumParameterIntValue("grid.psize", 1); AddParameter(ParameterType_Int, "grid.offsetx", "offset of the grid (x axis)"); - SetDefaultParameterInt ("grid.offsetx", 0); + SetDefaultParameterInt("grid.offsetx", 0); AddParameter(ParameterType_Int, "grid.offsety", "offset of the grid (y axis)"); - SetDefaultParameterInt ("grid.offsety", 0); + SetDefaultParameterInt("grid.offsety", 0); // Strategy AddParameter(ParameterType_Choice, "strategy", "Selection strategy for validation/training patches"); AddChoice("strategy.chessboard", "fifty fifty, like a chess board"); AddChoice("strategy.balanced", "you can chose the degree of spatial randomness vs class balance"); - AddParameter(ParameterType_Float, "strategy.balanced.sp", "Spatial proportion: between 0 and 1, " - "indicating the amount of randomly sampled data in space"); - SetMinimumParameterFloatValue ("strategy.balanced.sp", 0); - SetMaximumParameterFloatValue ("strategy.balanced.sp", 1); - SetDefaultParameterFloat ("strategy.balanced.sp", 0.25); - AddParameter(ParameterType_Int, "strategy.balanced.nclasses", "Number of classes"); - 
SetMinimumParameterIntValue ("strategy.balanced.nclasses", 2); - MandatoryOn ("strategy.balanced.nclasses"); + AddParameter(ParameterType_Float, + "strategy.balanced.sp", + "Spatial proportion: between 0 and 1, " + "indicating the amount of randomly sampled data in space"); + SetMinimumParameterFloatValue("strategy.balanced.sp", 0); + SetMaximumParameterFloatValue("strategy.balanced.sp", 1); + SetDefaultParameterFloat("strategy.balanced.sp", 0.25); + AddParameter(ParameterType_Int, "strategy.balanced.nclasses", "Number of classes"); + SetMinimumParameterIntValue("strategy.balanced.nclasses", 2); + MandatoryOn("strategy.balanced.nclasses"); AddParameter(ParameterType_InputImage, "strategy.balanced.labelimage", "input label image"); - MandatoryOn ("strategy.balanced.labelimage"); + MandatoryOn("strategy.balanced.labelimage"); // Output points AddParameter(ParameterType_OutputVectorData, "outtrain", "output set of points (training)"); AddParameter(ParameterType_OutputVectorData, "outvalid", "output set of points (validation)"); AddRAMParameter(); - } class SampleBundle { public: - SampleBundle(){} - SampleBundle(unsigned int nClasses){ - dist = DistributionType(nClasses); - id = 0; - (void) point; - black = true; - (void) index; - } - ~SampleBundle(){} - - SampleBundle(const SampleBundle & other){ - dist = other.GetDistribution(); - id = other.GetSampleID(); - point = other.GetPosition(); - black = other.GetBlack(); - index = other.GetIndex(); + SampleBundle() {} + explicit SampleBundle(unsigned int nClasses) + : dist(DistributionType(nClasses)) + , id(0) + , black(true) + { + (void)point; + (void)index; } - - DistributionType GetDistribution() const + ~SampleBundle() {} + + SampleBundle(const SampleBundle & other) + : dist(other.GetDistribution()) + , id(other.GetSampleID()) + , point(other.GetPosition()) + , black(other.GetBlack()) + , index(other.GetIndex()) + {} + + DistributionType + GetDistribution() const { return dist; } - DistributionType& 
GetModifiableDistribution() + DistributionType & + GetModifiableDistribution() { return dist; } - unsigned int GetSampleID() const + unsigned int + GetSampleID() const { return id; } - unsigned int& GetModifiableSampleID() + unsigned int & + GetModifiableSampleID() { return id; } - DataNodePointType GetPosition() const + DataNodePointType + GetPosition() const { return point; } - DataNodePointType& GetModifiablePosition() + DataNodePointType & + GetModifiablePosition() { return point; } - bool& GetModifiableBlack() + bool & + GetModifiableBlack() { return black; } - bool GetBlack() const + bool + GetBlack() const { return black; } - UInt8ImageType::IndexType& GetModifiableIndex() + UInt8ImageType::IndexType & + GetModifiableIndex() { return index; } - UInt8ImageType::IndexType GetIndex() const + UInt8ImageType::IndexType + GetIndex() const { return index; } private: - - DistributionType dist; - unsigned int id; - DataNodePointType point; - bool black; + DistributionType dist; + unsigned int id; + DataNodePointType point; + bool black; UInt8ImageType::IndexType index; }; /* * Apply the given function at each sampling location, checking if the patch is valid or not */ - template<typename TLambda> - void Apply(TLambda lambda) + template <typename TLambda> + void + Apply(TLambda lambda) { int userOffX = GetParameterInt("grid.offsetx"); int userOffY = GetParameterInt("grid.offsety"); + // Tell if the patch size is odd or even + const bool isEven = GetParameterInt("grid.psize") % 2 == 0; + otbAppLogINFO("Patch size is even: " << isEven); + // Explicit streaming over the morphed mask, based on the RAM parameter typedef otb::RAMDrivenStrippedStreamingManager<UInt8ImageType> StreamingManagerType; - StreamingManagerType::Pointer m_StreamingManager = StreamingManagerType::New(); + StreamingManagerType::Pointer m_StreamingManager = StreamingManagerType::New(); m_StreamingManager->SetAvailableRAMInMB(GetParameterInt("ram")); UInt8ImageType::Pointer inputImage; - bool readInput 
= true; - if (GetParameterInt("nocheck")==1) - { + bool readInput = true; + if (GetParameterInt("nocheck") == 1) + { otbAppLogINFO("\"nocheck\" mode is enabled. Input image pixels no-data values will not be checked."); if (HasValue("mask")) - { + { otbAppLogINFO("Using the provided \"mask\" parameter."); inputImage = GetParameterUInt8Image("mask"); - } + } else - { + { // This is just a hack to not trigger the whole morpho/pad pipeline inputImage = m_NoDataFilter->GetOutput(); readInput = false; - } } + } else - { + { inputImage = m_MorphoFilter->GetOutput(); // Offset update because the morpho filter pads the input image with 1 pixel border userOffX += 1; userOffY += 1; - } + } UInt8ImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion(); entireRegion.ShrinkByRadius(m_Radius); - m_StreamingManager->PrepareStreaming(inputImage, entireRegion ); + m_StreamingManager->PrepareStreaming(inputImage, entireRegion); UInt8ImageType::IndexType start; start[0] = m_Radius[0] + 1; start[1] = m_Radius[1] + 1; - int m_NumberOfDivisions = m_StreamingManager->GetNumberOfSplits(); - UInt8ImageType::IndexType pos; + int m_NumberOfDivisions = m_StreamingManager->GetNumberOfSplits(); + UInt8ImageType::IndexType pos; UInt8ImageType::IndexValueType step = GetParameterInt("grid.step"); pos.Fill(0); // Offset update - userOffX %= step ; - userOffY %= step ; + userOffX %= step; + userOffY %= step; for (int m_CurrentDivision = 0; m_CurrentDivision < m_NumberOfDivisions; m_CurrentDivision++) { @@ -304,7 +319,7 @@ public: UInt8ImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision); tf::PropagateRequestedRegion<UInt8ImageType>(inputImage, streamRegion); - itk::ImageRegionConstIterator<UInt8ImageType> inIt (inputImage, streamRegion); + itk::ImageRegionConstIterator<UInt8ImageType> inIt(inputImage, streamRegion); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { @@ -327,16 +342,19 @@ public: // Compute coordinates UInt8ImageType::PointType geo; 
inputImage->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo); - DataNodeType::PointType point; - point[0] = geo[0]; - point[1] = geo[1]; + + // Update geo if we want the corner or the center + if (isEven) + { + geo[0] -= 0.5 * std::abs(inputImage->GetSpacing()[0]); + geo[1] -= 0.5 * std::abs(inputImage->GetSpacing()[1]); + } // Lambda call lambda(pos, geo); } } } - } } @@ -348,21 +366,24 @@ public: { // Nb of samples (maximum) const UInt8ImageType::RegionType entireRegion = m_NoDataFilter->GetOutput()->GetLargestPossibleRegion(); - const unsigned int maxNbOfCols = std::ceil(entireRegion.GetSize(0)/GetParameterInt("grid.step")) + 1; - const unsigned int maxNbOfRows = std::ceil(entireRegion.GetSize(1)/GetParameterInt("grid.step")) + 1; - unsigned int maxNbOfSamples = 1; + const unsigned int maxNbOfCols = std::ceil(entireRegion.GetSize(0) / GetParameterInt("grid.step")) + 1; + const unsigned int maxNbOfRows = std::ceil(entireRegion.GetSize(1) / GetParameterInt("grid.step")) + 1; + unsigned int maxNbOfSamples = 1; maxNbOfSamples *= maxNbOfCols; maxNbOfSamples *= maxNbOfRows; // Nb of classes - SampleBundle initSB(nbOfClasses); + SampleBundle initSB(nbOfClasses); std::vector<SampleBundle> bundles(maxNbOfSamples, initSB); return bundles; } - void SetBlackOrWhiteBundle(SampleBundle & bundle, unsigned int & count, - const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) + void + SetBlackOrWhiteBundle(SampleBundle & bundle, + unsigned int & count, + const UInt8ImageType::IndexType & pos, + const UInt8ImageType::PointType & geo) { // Black or white bool black = ((pos[0] + pos[1]) % 2 == 0); @@ -372,20 +393,20 @@ public: bundle.GetModifiableBlack() = black; bundle.GetModifiableIndex() = pos; count++; - } /* * Samples are placed at regular intervals */ - void SampleChessboard() + void + SampleChessboard() { std::vector<SampleBundle> bundles = AllocateSamples(); unsigned int count = 0; - auto lambda = [this, &count, &bundles] - (const 
UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) { + auto lambda = [this, &count, &bundles](const UInt8ImageType::IndexType & pos, + const UInt8ImageType::PointType & geo) { SetBlackOrWhiteBundle(bundles[count], count, pos, geo); }; @@ -396,7 +417,8 @@ public: PopulateVectorData(bundles); } - void SampleBalanced() + void + SampleBalanced() { // 1. Compute distribution of all samples @@ -409,12 +431,13 @@ public: UInt8ImageType::SizeType patchSize; patchSize.Fill(GetParameterInt("grid.psize")); unsigned int count = 0; - auto lambda = [this, &bundles, &patchSize, &count] - (const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) { - + auto lambda = [this, &bundles, &patchSize, &count](const UInt8ImageType::IndexType & pos, + const UInt8ImageType::PointType & geo) { // Update this sample distribution if (tf::UpdateDistributionFromPatch<UInt8ImageType>(GetParameterUInt8Image("strategy.balanced.labelimage"), - geo, patchSize, bundles[count].GetModifiableDistribution())) + geo, + patchSize, + bundles[count].GetModifiableDistribution())) { SetBlackOrWhiteBundle(bundles[count], count, pos, geo); } @@ -423,7 +446,7 @@ public: Apply(lambda); bundles.resize(count); - otbAppLogINFO("Total number of candidates: " << count ); + otbAppLogINFO("Total number of candidates: " << count); // 2. 
Seed = spatially random samples @@ -433,13 +456,13 @@ public: otbAppLogINFO("Spatial sampling step " << samplingStep); - float step = 0; + float step = 0; std::vector<SampleBundle> seed(count); std::vector<SampleBundle> candidates(count); unsigned int seedCount = 0; unsigned int candidatesCount = 0; - for (auto& d: bundles) + for (auto & d : bundles) { if (d.GetIndex()[0] % samplingStep + d.GetIndex()[1] % samplingStep == 0) { @@ -465,18 +488,19 @@ public: float removalRate = static_cast<float>(seedCount) / static_cast<float>(nbToRemove); float removalStep = 0; - auto removeSamples = [&removalStep, &removalRate](SampleBundle & b) -> bool { - (void) b; + auto removeSamples = [&removalStep, &removalRate](SampleBundle & b) -> bool { + (void)b; bool ret = false; if (removalStep >= removalRate) - { + { removalStep = fmod(removalStep, removalRate); ret = true; - } + } else ret = false; removalStep++; - return ret;; + return ret; + ; }; auto iterator = std::remove_if(seed.begin(), seed.end(), removeSamples); seed.erase(iterator, seed.end()); @@ -486,8 +510,8 @@ public: // 3. 
Compute seed distribution const unsigned int nbOfClasses = GetParameterInt("strategy.balanced.nclasses"); - DistributionType seedDist(nbOfClasses); - for (auto& d: seed) + DistributionType seedDist(nbOfClasses); + for (auto & d : seed) seedDist.Update(d.GetDistribution()); otbAppLogINFO("Spatial seed distribution: " << seedDist.ToString()); @@ -497,16 +521,16 @@ public: otbAppLogINFO("Balance seed candidates size: " << candidates.size()); // Sort by cos - auto comparator = [&seedDist](const SampleBundle & a, const SampleBundle & b) -> bool{ + auto comparator = [&seedDist](const SampleBundle & a, const SampleBundle & b) -> bool { return a.GetDistribution().Cosinus(seedDist) > b.GetDistribution().Cosinus(seedDist); }; sort(candidates.begin(), candidates.end(), comparator); DistributionType idealDist(nbOfClasses, 1.0 / std::sqrt(static_cast<float>(nbOfClasses))); - float minCos = 0; - unsigned int samplesAdded = 0; - seed.resize(seed.size()+candidates.size(), SampleBundle(nbOfClasses)); - while(candidates.size() > 0) + float minCos = 0; + unsigned int samplesAdded = 0; + seed.resize(seed.size() + candidates.size(), SampleBundle(nbOfClasses)); + while (candidates.size() > 0) { // Get the less correlated sample SampleBundle candidate = candidates.back(); @@ -538,22 +562,23 @@ public: PopulateVectorData(seed); } - void PopulateVectorData(std::vector<SampleBundle> & samples) + void + PopulateVectorData(const std::vector<SampleBundle> & samples) { // Get data tree DataTreeType::Pointer treeTrain = m_OutVectorDataTrain->GetDataTree(); DataTreeType::Pointer treeValid = m_OutVectorDataValid->GetDataTree(); - DataNodePointer rootTrain = treeTrain->GetRoot()->Get(); - DataNodePointer rootValid = treeValid->GetRoot()->Get(); - DataNodePointer documentTrain = DataNodeType::New(); - DataNodePointer documentValid = DataNodeType::New(); + DataNodePointer rootTrain = treeTrain->GetRoot()->Get(); + DataNodePointer rootValid = treeValid->GetRoot()->Get(); + DataNodePointer 
documentTrain = DataNodeType::New(); + DataNodePointer documentValid = DataNodeType::New(); documentTrain->SetNodeType(DOCUMENT); documentValid->SetNodeType(DOCUMENT); treeTrain->Add(documentTrain, rootTrain); treeValid->Add(documentValid, rootValid); unsigned int id = 0; - for (const auto& sample: samples) + for (const auto & sample : samples) { // Add point to the VectorData tree DataNodePointer newDataNode = DataNodeType::New(); @@ -572,11 +597,11 @@ public: // Valid treeValid->Add(newDataNode, documentValid); } - } } - void DoExecute() + void + DoExecute() { otbAppLogINFO("Grid step : " << this->GetParameterInt("grid.step")); otbAppLogINFO("Patch size : " << this->GetParameterInt("grid.psize")); @@ -590,7 +615,7 @@ public: // If mask available, use it if (HasValue("mask")) - { + { if (GetParameterUInt8Image("mask")->GetLargestPossibleRegion().GetSize() != GetParameterFloatVectorImage("in")->GetLargestPossibleRegion().GetSize()) otbAppLogFATAL("Mask must have the same size as the input image!"); @@ -599,24 +624,24 @@ public: m_MaskImageFilter->SetMaskImage(GetParameterUInt8Image("mask")); m_MaskImageFilter->UpdateOutputInformation(); src = m_MaskImageFilter->GetOutput(); - } + } // Padding 1 pixel UInt8ImageType::SizeType size = src->GetLargestPossibleRegion().GetSize(); size[0] += 2; size[1] += 2; UInt8ImageType::SpacingType spacing = src->GetSignedSpacing(); - UInt8ImageType::PointType origin = src->GetOrigin(); + UInt8ImageType::PointType origin = src->GetOrigin(); origin[0] -= spacing[0]; origin[1] -= spacing[1]; m_PadFilter = PadFilterType::New(); NNInterpolatorType::Pointer nnInterpolator = NNInterpolatorType::New(); m_PadFilter->SetInterpolator(nnInterpolator); - m_PadFilter->SetInput( src ); + m_PadFilter->SetInput(src); m_PadFilter->SetOutputOrigin(origin); m_PadFilter->SetOutputSpacing(spacing); m_PadFilter->SetOutputSize(size); - m_PadFilter->SetEdgePaddingValue( 0 ); + m_PadFilter->SetEdgePaddingValue(0); m_PadFilter->UpdateOutputInformation(); // 
Morpho @@ -649,13 +674,17 @@ public: SampleBalanced(); } - otbAppLogINFO( "Writing output samples positions"); + otbAppLogINFO("Writing output samples positions"); SetParameterOutputVectorData("outtrain", m_OutVectorDataTrain); SetParameterOutputVectorData("outvalid", m_OutVectorDataValid); - } + + void + DoUpdateParameters() + {} + private: RadiusType m_Radius; IsNoDataFilterType::Pointer m_NoDataFilter; @@ -666,7 +695,7 @@ private: MaskImageFilterType::Pointer m_MaskImageFilter; }; // end of class -} // end namespace wrapper +} // namespace Wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::PatchesSelection ) +OTB_APPLICATION_EXPORT(otb::Wrapper::PatchesSelection) diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx index 0aa9d71ec29345ecd332062d030e3ca0ea70eab6..b9f74dfc71d90bd17226deffd60a79b6039600e6 100644 --- a/app/otbTensorflowModelServe.cxx +++ b/app/otbTensorflowModelServe.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -16,9 +17,8 @@ #include "otbStandardFilterWatcher.h" #include "itkFixedArray.h" -// Tensorflow stuff -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/platform/env.h" +// Tensorflow SavedModel +#include "tensorflow/cc/saved_model/loader.h" // Tensorflow model filter #include "otbTensorflowMultisourceModelFilter.h" @@ -42,10 +42,10 @@ class TensorflowModelServe : public Application { public: /** Standard class typedefs. 
*/ - typedef TensorflowModelServe Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef TensorflowModelServe Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); @@ -53,19 +53,15 @@ public: /** Typedefs for tensorflow */ typedef otb::TensorflowMultisourceModelFilter<FloatVectorImageType, FloatVectorImageType> TFModelFilterType; - typedef otb::TensorflowSource<FloatVectorImageType> InputImageSource; + typedef otb::TensorflowSource<FloatVectorImageType> InputImageSource; /** Typedef for streaming */ - typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType; + typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType; typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType; /** Typedefs for images */ typedef FloatVectorImageType::SizeType SizeType; - void DoUpdateParameters() - { - } - // // Store stuff related to one source // @@ -87,147 +83,164 @@ public: // -an input image list // -an input patchsize (dimensions of samples) // - void AddAnInputImage() + void + AddAnInputImage() { // Number of source unsigned int inputNumber = m_Bundles.size() + 1; // Create keys and descriptions - std::stringstream ss_key_group, ss_desc_group, - ss_key_in, ss_desc_in, - ss_key_dims_x, ss_desc_dims_x, - ss_key_dims_y, ss_desc_dims_y, - ss_key_ph, ss_desc_ph; + std::stringstream ss_key_group, ss_desc_group, ss_key_in, ss_desc_in, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, + ss_desc_dims_y, ss_key_ph, ss_desc_ph; // Parameter group key/description - ss_key_group << "source" << inputNumber; + ss_key_group << "source" << inputNumber; ss_desc_group << "Parameters for source #" << inputNumber; // Parameter group keys - ss_key_in << 
ss_key_group.str() << ".il"; - ss_key_dims_x << ss_key_group.str() << ".rfieldx"; - ss_key_dims_y << ss_key_group.str() << ".rfieldy"; - ss_key_ph << ss_key_group.str() << ".placeholder"; + ss_key_in << ss_key_group.str() << ".il"; + ss_key_dims_x << ss_key_group.str() << ".rfieldx"; + ss_key_dims_y << ss_key_group.str() << ".rfieldy"; + ss_key_ph << ss_key_group.str() << ".placeholder"; // Parameter group descriptions - ss_desc_in << "Input image (or list to stack) for source #" << inputNumber; - ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber; + ss_desc_in << "Input image (or list to stack) for source #" << inputNumber; + ss_desc_dims_x << "Input receptive field (width) for source #" << inputNumber; ss_desc_dims_y << "Input receptive field (height) for source #" << inputNumber; - ss_desc_ph << "Name of the input placeholder for source #" << inputNumber; + ss_desc_ph << "Name of the input placeholder for source #" << inputNumber; // Populate group - AddParameter(ParameterType_Group, ss_key_group.str(), ss_desc_group.str()); - AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() ); - AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); - SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); - AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); - SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); - AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str()); + AddParameter(ParameterType_Group, ss_key_group.str(), ss_desc_group.str()); + AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str()); + AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); + SetMinimumParameterIntValue(ss_key_dims_x.str(), 1); + SetDefaultParameterInt(ss_key_dims_x.str(), 1); + AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); + SetMinimumParameterIntValue(ss_key_dims_y.str(), 1); + 
SetDefaultParameterInt(ss_key_dims_y.str(), 1); + AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str()); + MandatoryOff(ss_key_ph.str()); // Add a new bundle ProcessObjectsBundle bundle; - bundle.m_KeyIn = ss_key_in.str(); - bundle.m_KeyPszX = ss_key_dims_x.str(); - bundle.m_KeyPszY = ss_key_dims_y.str(); + bundle.m_KeyIn = ss_key_in.str(); + bundle.m_KeyPszX = ss_key_dims_x.str(); + bundle.m_KeyPszY = ss_key_dims_y.str(); bundle.m_KeyPHName = ss_key_ph.str(); m_Bundles.push_back(bundle); - } - void DoInit() + void + DoInit() { // Documentation SetName("TensorflowModelServe"); - SetDescription("Multisource deep learning classifier using TensorFlow. Change the " - + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of sources."); + SetDescription("Multisource deep learning classifier using TensorFlow. Change the " + tf::ENV_VAR_NAME_NSOURCES + + " environment variable to set the number of sources."); SetDocLongDescription("The application run a TensorFlow model over multiple data sources. " - "The number of input sources can be changed at runtime by setting the system " - "environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". For each source, you have to " - "set (1) the placeholder name, as named in the TensorFlow model, (2) the receptive " - "field and (3) the image(s) source. The output is a multiband image, stacking all " - "outputs tensors together: you have to specify (1) the names of the output tensors, as " - "named in the TensorFlow model (typically, an operator's output) and (2) the expression " - "field of each output tensor. The output tensors values will be stacked in the same " - "order as they appear in the \"model.output\" parameter (you can use a space separator " - "between names). 
You might consider to use extended filename to bypass the automatic " - "memory footprint calculator of the otb application engine, and set a good splitting " - "strategy (Square tiles is good for convolutional networks) or use the \"optim\" " - "parameter group to impose your squared tiles sizes"); + "The number of input sources can be changed at runtime by setting the system " + "environment variable " + + tf::ENV_VAR_NAME_NSOURCES + + ". For each source, you have to " + "set (1) the placeholder name, as named in the TensorFlow model, (2) the receptive " + "field and (3) the image(s) source. The output is a multiband image, stacking all " + "outputs tensors together: you have to specify (1) the names of the output tensors, as " + "named in the TensorFlow model (typically, an operator's output) and (2) the expression " + "field of each output tensor. The output tensors values will be stacked in the same " + "order as they appear in the \"model.output\" parameter (you can use a space separator " + "between names). You might consider to use extended filename to bypass the automatic " + "memory footprint calculator of the otb application engine, and set a good splitting " + "strategy (Square tiles is good for convolutional networks) or use the \"optim\" " + "parameter group to impose your squared tiles sizes"); SetDocAuthors("Remi Cresson"); AddDocTag(Tags::Learning); // Input/output images AddAnInputImage(); - for (int i = 1; i < tf::GetNumberOfSources() ; i++) + for (int i = 1; i < tf::GetNumberOfSources(); i++) AddAnInputImage(); // Input model - AddParameter(ParameterType_Group, "model", "model parameters"); - AddParameter(ParameterType_Directory, "model.dir", "TensorFlow model_save directory"); - MandatoryOn ("model.dir"); - SetParameterDescription ("model.dir", "The model directory should contains the model Google Protobuf (.pb) and variables"); - - AddParameter(ParameterType_StringList, "model.userplaceholders", "Additional single-valued placeholders. 
Supported types: int, float, bool."); - MandatoryOff ("model.userplaceholders"); - SetParameterDescription ("model.userplaceholders", "Syntax to use is \"placeholder_1=value_1 ... placeholder_N=value_N\""); - AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional"); - MandatoryOff ("model.fullyconv"); + AddParameter(ParameterType_Group, "model", "model parameters"); + AddParameter(ParameterType_Directory, "model.dir", "TensorFlow SavedModel directory"); + MandatoryOn("model.dir"); + SetParameterDescription("model.dir", + "The model directory should contains the model Google Protobuf (.pb) and variables"); + + AddParameter(ParameterType_StringList, + "model.userplaceholders", + "Additional single-valued placeholders. Supported types: int, float, bool."); + MandatoryOff("model.userplaceholders"); + SetParameterDescription("model.userplaceholders", + "Syntax to use is \"placeholder_1=value_1 ... placeholder_N=value_N\""); + AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional"); + MandatoryOff("model.fullyconv"); + AddParameter(ParameterType_StringList, + "model.tagsets", + "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Currently, only one tag is " + "supported. 
Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`"); + MandatoryOff("model.tagsets"); // Output tensors parameters - AddParameter(ParameterType_Group, "output", "Output tensors parameters"); - AddParameter(ParameterType_Float, "output.spcscale", "The output spacing scale, related to the first input"); - SetDefaultParameterFloat ("output.spcscale", 1.0); - SetParameterDescription ("output.spcscale", "The output image size/scale and spacing*scale where size and spacing corresponds to the first input"); - AddParameter(ParameterType_StringList, "output.names", "Names of the output tensors"); - MandatoryOn ("output.names"); + AddParameter(ParameterType_Group, "output", "Output tensors parameters"); + AddParameter(ParameterType_Float, "output.spcscale", "The output spacing scale, related to the first input"); + SetDefaultParameterFloat("output.spcscale", 1.0); + SetParameterDescription( + "output.spcscale", + "The output image size/scale and spacing*scale where size and spacing corresponds to the first input"); + AddParameter(ParameterType_StringList, "output.names", "Names of the output tensors"); + MandatoryOff("output.names"); // Output Field of Expression - AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)"); - SetMinimumParameterIntValue ("output.efieldx", 1); - SetDefaultParameterInt ("output.efieldx", 1); - MandatoryOn ("output.efieldx"); - AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)"); - SetMinimumParameterIntValue ("output.efieldy", 1); - SetDefaultParameterInt ("output.efieldy", 1); - MandatoryOn ("output.efieldy"); + AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)"); + SetMinimumParameterIntValue("output.efieldx", 1); + SetDefaultParameterInt("output.efieldx", 1); + MandatoryOn("output.efieldx"); + AddParameter(ParameterType_Int, "output.efieldy", "The output expression field (height)"); + 
SetMinimumParameterIntValue("output.efieldy", 1); + SetDefaultParameterInt("output.efieldy", 1); + MandatoryOn("output.efieldy"); // Fine tuning - AddParameter(ParameterType_Group, "optim" , "This group of parameters allows optimization of processing time"); - AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling"); - MandatoryOff ("optim.disabletiling"); - SetParameterDescription ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it"); - AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output"); - SetMinimumParameterIntValue ("optim.tilesizex", 1); - SetDefaultParameterInt ("optim.tilesizex", 16); - AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output"); - SetMinimumParameterIntValue ("optim.tilesizey", 1); - SetDefaultParameterInt ("optim.tilesizey", 16); + AddParameter(ParameterType_Group, "optim", "This group of parameters allows optimization of processing time"); + AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling"); + MandatoryOff("optim.disabletiling"); + SetParameterDescription( + "optim.disabletiling", + "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it"); + AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output"); + SetMinimumParameterIntValue("optim.tilesizex", 1); + SetDefaultParameterInt("optim.tilesizex", 16); + AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output"); + SetMinimumParameterIntValue("optim.tilesizey", 1); + SetDefaultParameterInt("optim.tilesizey", 16); // Output image AddParameter(ParameterType_OutputImage, "out", "output image"); // Example - SetDocExampleParameterValue("source1.il", "spot6pms.tif"); - SetDocExampleParameterValue("source1.placeholder", "x1"); - SetDocExampleParameterValue("source1.rfieldx", 
"16"); - SetDocExampleParameterValue("source1.rfieldy", "16"); - SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); + SetDocExampleParameterValue("source1.il", "spot6pms.tif"); + SetDocExampleParameterValue("source1.placeholder", "x1"); + SetDocExampleParameterValue("source1.rfieldx", "16"); + SetDocExampleParameterValue("source1.rfieldy", "16"); + SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); SetDocExampleParameterValue("model.userplaceholders", "is_training=false dropout=0.0"); - SetDocExampleParameterValue("output.names", "out_predict1 out_proba1"); - SetDocExampleParameterValue("out", "\"classif128tgt.tif?&streaming:type=tiled&streaming:sizemode=height&streaming:sizevalue=256\""); - + SetDocExampleParameterValue("output.names", "out_predict1 out_proba1"); + SetDocExampleParameterValue( + "out", "\"classif128tgt.tif?&streaming:type=tiled&streaming:sizemode=height&streaming:sizevalue=256\""); } // // Prepare bundles from the number of points // - void PrepareInputs() + void + PrepareInputs() { - for (auto& bundle: m_Bundles) + for (auto & bundle : m_Bundles) { // Setting the image source FloatVectorImageListType::Pointer list = GetParameterImageList(bundle.m_KeyIn); @@ -237,32 +250,32 @@ public: bundle.m_PatchSize[1] = GetParameterInt(bundle.m_KeyPszY); otbAppLogINFO("Source info :"); - otbAppLogINFO("Receptive field : " << bundle.m_PatchSize ); + otbAppLogINFO("Receptive field : " << bundle.m_PatchSize); otbAppLogINFO("Placeholder name : " << bundle.m_Placeholder); } } - void DoExecute() + void + DoExecute() { // Load the Tensorflow bundle - tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel); + tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, GetParameterStringList("model.tagsets")); // Prepare inputs PrepareInputs(); // Setup filter m_TFFilter = TFModelFilterType::New(); - m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); - m_TFFilter->SetSession(m_SavedModel.session.get()); + 
m_TFFilter->SetSavedModel(&m_SavedModel); m_TFFilter->SetOutputTensors(GetParameterStringList("output.names")); m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale")); otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale()); // Get user placeholders TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders"); - TFModelFilterType::DictType dict; - for (auto& exp: expressions) + TFModelFilterType::DictType dict; + for (auto & exp : expressions) { TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp); dict.push_back(entry); @@ -272,13 +285,13 @@ public: m_TFFilter->SetUserPlaceholders(dict); // Input sources - for (auto& bundle: m_Bundles) + for (auto & bundle : m_Bundles) { m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get()); } // Fully convolutional mode on/off - if (GetParameterInt("model.fullyconv")==1) + if (GetParameterInt("model.fullyconv") == 1) { otbAppLogINFO("The TensorFlow model is used in fully convolutional mode"); m_TFFilter->SetFullyConvolutional(true); @@ -288,7 +301,7 @@ public: FloatVectorImageType::SizeType foe; foe[0] = GetParameterInt("output.efieldx"); foe[1] = GetParameterInt("output.efieldy"); - m_TFFilter->SetOutputExpressionFields({foe}); + m_TFFilter->SetOutputExpressionFields({ foe }); otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]); @@ -301,22 +314,22 @@ public: tileSize[1] = GetParameterInt("optim.tilesizey"); // Check that the tile size is aligned to the field of expression - for (unsigned int i = 0 ; i < FloatVectorImageType::ImageDimension ; i++) + for (unsigned int i = 0; i < FloatVectorImageType::ImageDimension; i++) if (tileSize[i] % foe[i] != 0) - { + { SizeType::SizeValueType newSize = 1 + std::floor(tileSize[i] / foe[i]); newSize *= foe[i]; otbAppLogWARNING("Aligning the tiling to the output expression field " - << "for better performances 
(dim " << i << "). New value set to " << newSize) + << "for better performances (dim " << i << "). New value set to " << newSize) - tileSize[i] = newSize; - } + tileSize[i] = newSize; + } otbAppLogINFO("Force tiling with squared tiles of " << tileSize) - // Force the computation tile by tile - m_StreamFilter = StreamingFilterType::New(); + // Force the computation tile by tile + m_StreamFilter = StreamingFilterType::New(); m_StreamFilter->SetOutputGridSize(tileSize); m_StreamFilter->SetInput(m_TFFilter->GetOutput()); @@ -329,17 +342,21 @@ public: } } -private: + void + DoUpdateParameters() + {} + +private: TFModelFilterType::Pointer m_TFFilter; StreamingFilterType::Pointer m_StreamFilter; tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application ! - std::vector<ProcessObjectsBundle> m_Bundles; + std::vector<ProcessObjectsBundle> m_Bundles; }; // end of class -} // namespace wrapper +} // namespace Wrapper } // namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::TensorflowModelServe ) +OTB_APPLICATION_EXPORT(otb::Wrapper::TensorflowModelServe) diff --git a/app/otbTensorflowModelTrain.cxx b/app/otbTensorflowModelTrain.cxx index b37c72c3df7a297dda0857030d47bbe1e08cad61..f5a420a9625a03c086df81a1d43db78f75be4e0a 100644 --- a/app/otbTensorflowModelTrain.cxx +++ b/app/otbTensorflowModelTrain.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. 
+ Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -16,9 +17,8 @@ #include "otbStandardFilterWatcher.h" #include "itkFixedArray.h" -// Tensorflow stuff -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/platform/env.h" +// Tensorflow SavedModel +#include "tensorflow/cc/saved_model/loader.h" // Tensorflow model train #include "otbTensorflowMultisourceModelTrain.h" @@ -42,12 +42,11 @@ namespace Wrapper class TensorflowModelTrain : public Application { public: - /** Standard class typedefs. */ - typedef TensorflowModelTrain Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef TensorflowModelTrain Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); @@ -77,8 +76,8 @@ public: std::string m_KeyInForValid; // Key of input image list (validation) std::string m_KeyPHNameForTrain; // Key for placeholder name in the TensorFlow model (training) std::string m_KeyPHNameForValid; // Key for placeholder name in the TensorFlow model (validation) - std::string m_KeyPszX; // Key for samples sizes X - std::string m_KeyPszY; // Key for samples sizes Y + std::string m_KeyPszX; // Key for samples sizes X + std::string m_KeyPszY; // Key for samples sizes Y }; /** Typedefs for the app */ @@ -86,9 +85,9 @@ public: typedef std::vector<FloatVectorImageType::SizeType> SizeList; typedef std::vector<std::string> StringList; - void DoUpdateParameters() - { - } + void + DoUpdateParameters() + {} // // Add an input source, which includes: @@ -98,149 +97,161 @@ public: // -an input image placeholder (for validation) // -an input patchsize, which is the dimensions of samples. Same for training and validation. 
// - void AddAnInputImage() + void + AddAnInputImage() { // Number of source unsigned int inputNumber = m_Bundles.size() + 1; // Create keys and descriptions - std::stringstream ss_key_tr_group, ss_desc_tr_group, - ss_key_val_group, ss_desc_val_group, - ss_key_tr_in, ss_desc_tr_in, - ss_key_val_in, ss_desc_val_in, - ss_key_dims_x, ss_desc_dims_x, - ss_key_dims_y, ss_desc_dims_y, - ss_key_tr_ph, ss_desc_tr_ph, - ss_key_val_ph, ss_desc_val_ph; + std::stringstream ss_key_tr_group, ss_desc_tr_group, ss_key_val_group, ss_desc_val_group, ss_key_tr_in, + ss_desc_tr_in, ss_key_val_in, ss_desc_val_in, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, ss_desc_dims_y, + ss_key_tr_ph, ss_desc_tr_ph, ss_key_val_ph, ss_desc_val_ph; // Parameter group key/description - ss_key_tr_group << "training.source" << inputNumber; - ss_key_val_group << "validation.source" << inputNumber; - ss_desc_tr_group << "Parameters for source #" << inputNumber << " (training)"; + ss_key_tr_group << "training.source" << inputNumber; + ss_key_val_group << "validation.source" << inputNumber; + ss_desc_tr_group << "Parameters for source #" << inputNumber << " (training)"; ss_desc_val_group << "Parameters for source #" << inputNumber << " (validation)"; // Parameter group keys - ss_key_tr_in << ss_key_tr_group.str() << ".il"; - ss_key_val_in << ss_key_val_group.str() << ".il"; - ss_key_dims_x << ss_key_tr_group.str() << ".patchsizex"; - ss_key_dims_y << ss_key_tr_group.str() << ".patchsizey"; - ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder"; - ss_key_val_ph << ss_key_val_group.str() << ".name"; + ss_key_tr_in << ss_key_tr_group.str() << ".il"; + ss_key_val_in << ss_key_val_group.str() << ".il"; + ss_key_dims_x << ss_key_tr_group.str() << ".patchsizex"; + ss_key_dims_y << ss_key_tr_group.str() << ".patchsizey"; + ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder"; + ss_key_val_ph << ss_key_val_group.str() << ".name"; // Parameter group descriptions - ss_desc_tr_in << "Input image (or list to 
stack) for source #" << inputNumber << " (training)"; + ss_desc_tr_in << "Input image (or list to stack) for source #" << inputNumber << " (training)"; ss_desc_val_in << "Input image (or list to stack) for source #" << inputNumber << " (validation)"; - ss_desc_dims_x << "Patch size (x) for source #" << inputNumber; - ss_desc_dims_y << "Patch size (y) for source #" << inputNumber; - ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)"; + ss_desc_dims_x << "Patch size (x) for source #" << inputNumber; + ss_desc_dims_y << "Patch size (y) for source #" << inputNumber; + ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)"; ss_desc_val_ph << "Name of the input placeholder " - "or output tensor for source #" << inputNumber << " (validation)"; + "or output tensor for source #" + << inputNumber << " (validation)"; // Populate group - AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str()); - AddParameter(ParameterType_InputImageList, ss_key_tr_in.str(), ss_desc_tr_in.str() ); - AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); - SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); - AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); - SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); - AddParameter(ParameterType_String, ss_key_tr_ph.str(), ss_desc_tr_ph.str()); - AddParameter(ParameterType_Group, ss_key_val_group.str(), ss_desc_val_group.str()); - AddParameter(ParameterType_InputImageList, ss_key_val_in.str(), ss_desc_val_in.str() ); - AddParameter(ParameterType_String, ss_key_val_ph.str(), ss_desc_val_ph.str()); + AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str()); + AddParameter(ParameterType_InputImageList, ss_key_tr_in.str(), ss_desc_tr_in.str()); + AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); + SetMinimumParameterIntValue(ss_key_dims_x.str(), 1); + 
AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); + SetMinimumParameterIntValue(ss_key_dims_y.str(), 1); + AddParameter(ParameterType_String, ss_key_tr_ph.str(), ss_desc_tr_ph.str()); + AddParameter(ParameterType_Group, ss_key_val_group.str(), ss_desc_val_group.str()); + AddParameter(ParameterType_InputImageList, ss_key_val_in.str(), ss_desc_val_in.str()); + AddParameter(ParameterType_String, ss_key_val_ph.str(), ss_desc_val_ph.str()); // Add a new bundle ProcessObjectsBundle bundle; - bundle.m_KeyInForTrain = ss_key_tr_in.str(); - bundle.m_KeyInForValid = ss_key_val_in.str(); + bundle.m_KeyInForTrain = ss_key_tr_in.str(); + bundle.m_KeyInForValid = ss_key_val_in.str(); bundle.m_KeyPHNameForTrain = ss_key_tr_ph.str(); bundle.m_KeyPHNameForValid = ss_key_val_ph.str(); - bundle.m_KeyPszX = ss_key_dims_x.str(); - bundle.m_KeyPszY = ss_key_dims_y.str(); + bundle.m_KeyPszX = ss_key_dims_x.str(); + bundle.m_KeyPszY = ss_key_dims_y.str(); m_Bundles.push_back(bundle); } - void DoInit() + void + DoInit() { // Documentation SetName("TensorflowModelTrain"); SetDescription("Train a multisource deep learning net using Tensorflow. Change " - "the " + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of " - "sources."); + "the " + + tf::ENV_VAR_NAME_NSOURCES + + " environment variable to set the number of " + "sources."); SetDocLongDescription("The application trains a Tensorflow model over multiple data sources. " - "The number of input sources can be changed at runtime by setting the " - "system environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". " - "For each source, you have to set (1) the tensor placeholder name, as named in " - "the tensorflow model, (2) the patch size and (3) the image(s) source. "); + "The number of input sources can be changed at runtime by setting the " + "system environment variable " + + tf::ENV_VAR_NAME_NSOURCES + + ". 
" + "For each source, you have to set (1) the tensor placeholder name, as named in " + "the tensorflow model, (2) the patch size and (3) the image(s) source. "); SetDocAuthors("Remi Cresson"); AddDocTag(Tags::Learning); // Input model - AddParameter(ParameterType_Group, "model", "Model parameters"); - AddParameter(ParameterType_Directory, "model.dir", "Tensorflow model_save directory"); - MandatoryOn ("model.dir"); - AddParameter(ParameterType_String, "model.restorefrom", "Restore model from path"); - MandatoryOff ("model.restorefrom"); - AddParameter(ParameterType_String, "model.saveto", "Save model to path"); - MandatoryOff ("model.saveto"); + AddParameter(ParameterType_Group, "model", "Model parameters"); + AddParameter(ParameterType_Directory, "model.dir", "Tensorflow model_save directory"); + MandatoryOn("model.dir"); + AddParameter(ParameterType_String, "model.restorefrom", "Restore model from path"); + MandatoryOff("model.restorefrom"); + AddParameter(ParameterType_String, "model.saveto", "Save model to path"); + MandatoryOff("model.saveto"); + AddParameter(ParameterType_StringList, + "model.tagsets", + "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Currently, only one tag is " + "supported. 
Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`"); + MandatoryOff("model.tagsets"); // Training parameters group - AddParameter(ParameterType_Group, "training", "Training parameters"); - AddParameter(ParameterType_Int, "training.batchsize", "Batch size"); - SetMinimumParameterIntValue ("training.batchsize", 1); - SetDefaultParameterInt ("training.batchsize", 100); - AddParameter(ParameterType_Int, "training.epochs", "Number of epochs"); - SetMinimumParameterIntValue ("training.epochs", 1); - SetDefaultParameterInt ("training.epochs", 100); - AddParameter(ParameterType_StringList, "training.userplaceholders", + AddParameter(ParameterType_Group, "training", "Training parameters"); + AddParameter(ParameterType_Int, "training.batchsize", "Batch size"); + SetMinimumParameterIntValue("training.batchsize", 1); + SetDefaultParameterInt("training.batchsize", 100); + AddParameter(ParameterType_Int, "training.epochs", "Number of epochs"); + SetMinimumParameterIntValue("training.epochs", 1); + SetDefaultParameterInt("training.epochs", 100); + AddParameter(ParameterType_StringList, + "training.userplaceholders", "Additional single-valued placeholders for training. 
Supported types: int, float, bool."); - MandatoryOff ("training.userplaceholders"); - AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes"); - MandatoryOn ("training.targetnodes"); - AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display"); - MandatoryOff ("training.outputtensors"); - AddParameter(ParameterType_Bool, "training.usestreaming", "Use the streaming through patches (slower but can process big dataset)"); - MandatoryOff ("training.usestreaming"); + MandatoryOff("training.userplaceholders"); + AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes"); + MandatoryOn("training.targetnodes"); + AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display"); + MandatoryOff("training.outputtensors"); + AddParameter(ParameterType_Bool, + "training.usestreaming", + "Use the streaming through patches (slower but can process big dataset)"); + MandatoryOff("training.usestreaming"); // Metrics - AddParameter(ParameterType_Group, "validation", "Validation parameters"); - MandatoryOff ("validation"); - AddParameter(ParameterType_Int, "validation.step", "Perform the validation every Nth epochs"); - SetMinimumParameterIntValue ("validation.step", 1); - SetDefaultParameterInt ("validation.step", 10); - AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute"); - AddChoice ("validation.mode.none", "No validation step"); - AddChoice ("validation.mode.class", "Classification metrics"); - AddChoice ("validation.mode.rmse", "Root mean square error"); - AddParameter(ParameterType_StringList, "validation.userplaceholders", + AddParameter(ParameterType_Group, "validation", "Validation parameters"); + MandatoryOff("validation"); + AddParameter(ParameterType_Int, "validation.step", "Perform the validation every Nth epochs"); + SetMinimumParameterIntValue("validation.step", 1); + 
SetDefaultParameterInt("validation.step", 10); + AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute"); + AddChoice("validation.mode.none", "No validation step"); + AddChoice("validation.mode.class", "Classification metrics"); + AddChoice("validation.mode.rmse", "Root mean square error"); + AddParameter(ParameterType_StringList, + "validation.userplaceholders", "Additional single-valued placeholders for validation. Supported types: int, float, bool."); - MandatoryOff ("validation.userplaceholders"); - AddParameter(ParameterType_Bool, "validation.usestreaming", "Use the streaming through patches (slower but can process big dataset)"); - MandatoryOff ("validation.usestreaming"); + MandatoryOff("validation.userplaceholders"); + AddParameter(ParameterType_Bool, + "validation.usestreaming", + "Use the streaming through patches (slower but can process big dataset)"); + MandatoryOff("validation.usestreaming"); // Input/output images AddAnInputImage(); - for (int i = 1; i < tf::GetNumberOfSources() + 1 ; i++) // +1 because we have at least 1 source more for training - { + for (int i = 1; i < tf::GetNumberOfSources() + 1; i++) // +1 because we have at least 1 source more for training + { AddAnInputImage(); - } + } // Example - SetDocExampleParameterValue("source1.il", "spot6pms.tif"); - SetDocExampleParameterValue("source1.placeholder", "x1"); - SetDocExampleParameterValue("source1.patchsizex", "16"); - SetDocExampleParameterValue("source1.patchsizey", "16"); - SetDocExampleParameterValue("source2.il", "labels.tif"); - SetDocExampleParameterValue("source2.placeholder", "y1"); - SetDocExampleParameterValue("source2.patchsizex", "1"); - SetDocExampleParameterValue("source2.patchsizex", "1"); - SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); + SetDocExampleParameterValue("source1.il", "spot6pms.tif"); + SetDocExampleParameterValue("source1.placeholder", "x1"); + SetDocExampleParameterValue("source1.patchsizex", "16"); + 
SetDocExampleParameterValue("source1.patchsizey", "16"); + SetDocExampleParameterValue("source2.il", "labels.tif"); + SetDocExampleParameterValue("source2.placeholder", "y1"); + SetDocExampleParameterValue("source2.patchsizex", "1"); + SetDocExampleParameterValue("source2.patchsizex", "1"); + SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2"); - SetDocExampleParameterValue("training.targetnodes", "optimizer"); - SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model/variables/variables"); - + SetDocExampleParameterValue("training.targetnodes", "optimizer"); + SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model/variables/variables"); } // @@ -259,7 +270,8 @@ public: // if we can keep trace of indices of sources for // training / test / validation // - void PrepareInputs() + void + PrepareInputs() { // Clear placeholder names m_InputPlaceholdersForTraining.clear(); @@ -281,8 +293,8 @@ public: // Prepare the bundles - for (auto& bundle: m_Bundles) - { + for (auto & bundle : m_Bundles) + { // Source FloatVectorImageListType::Pointer trainStack = GetParameterImageList(bundle.m_KeyInForTrain); bundle.tfSource.Set(trainStack); @@ -299,17 +311,17 @@ public: m_InputPatchesSizeForTraining.push_back(patchSize); otbAppLogINFO("New source:"); - otbAppLogINFO("Patch size : "<< patchSize); - otbAppLogINFO("Placeholder (training) : "<< placeholderForTraining); + otbAppLogINFO("Patch size : " << patchSize); + otbAppLogINFO("Placeholder (training) : " << placeholderForTraining); // Prepare validation sources if (GetParameterInt("validation.mode") != 0) - { + { // Get the stack if (!HasValue(bundle.m_KeyInForValid)) - { + { otbAppLogFATAL("No validation input is set for this source"); - } + } FloatVectorImageListType::Pointer validStack = GetParameterImageList(bundle.m_KeyInForValid); bundle.tfSourceForValidation.Set(validStack); @@ -317,12 +329,12 @@ public: 
// If yes, it means that its not an output tensor on which perform the validation std::string placeholderForValidation = GetParameterAsString(bundle.m_KeyPHNameForValid); if (placeholderForValidation.empty()) - { + { placeholderForValidation = placeholderForTraining; - } + } // Same placeholder name ==> is a source for validation if (placeholderForValidation.compare(placeholderForTraining) == 0) - { + { // Source m_InputSourcesForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get()); m_InputSourcesForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get()); @@ -333,12 +345,11 @@ public: // Patch size m_InputPatchesSizeForValidation.push_back(patchSize); - otbAppLogINFO("Placeholder (validation) : "<< placeholderForValidation); - - } + otbAppLogINFO("Placeholder (validation) : " << placeholderForValidation); + } // Different placeholder ==> is a target to validate else - { + { // Source m_InputTargetsForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get()); m_InputTargetsForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get()); @@ -349,51 +360,54 @@ public: // Patch size m_TargetPatchesSize.push_back(patchSize); - otbAppLogINFO("Tensor name (validation) : "<< placeholderForValidation); - } - + otbAppLogINFO("Tensor name (validation) : " << placeholderForValidation); } - } + } } // // Get user placeholders // - TrainModelFilterType::DictType GetUserPlaceholders(const std::string key) + TrainModelFilterType::DictType + GetUserPlaceholders(const std::string & key) { - TrainModelFilterType::DictType dict; + TrainModelFilterType::DictType dict; TrainModelFilterType::StringList expressions = GetParameterStringList(key); - for (auto& exp: expressions) - { + for (auto & exp : expressions) + { TrainModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp); dict.push_back(entry); otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second)); - } + } return dict; } 
// // Print some classification metrics // - void PrintClassificationMetrics(const ConfMatType & confMat, const MapOfClassesType & mapOfClassesRef) + void + PrintClassificationMetrics(const ConfMatType & confMat, const MapOfClassesType & mapOfClassesRef) { ConfusionMatrixCalculatorType::Pointer confMatMeasurements = ConfusionMatrixCalculatorType::New(); confMatMeasurements->SetConfusionMatrix(confMat); confMatMeasurements->SetMapOfClasses(mapOfClassesRef); confMatMeasurements->Compute(); - for (auto const& itMapOfClassesRef : mapOfClassesRef) - { + for (auto const & itMapOfClassesRef : mapOfClassesRef) + { LabelValueType labelRef = itMapOfClassesRef.first; LabelValueType indexLabelRef = itMapOfClassesRef.second; - otbAppLogINFO("Precision of class [" << labelRef << "] vs all: " << confMatMeasurements->GetPrecisions()[indexLabelRef]); - otbAppLogINFO("Recall of class [" << labelRef << "] vs all: " << confMatMeasurements->GetRecalls()[indexLabelRef]); - otbAppLogINFO("F-score of class [" << labelRef << "] vs all: " << confMatMeasurements->GetFScores()[indexLabelRef]); + otbAppLogINFO("Precision of class [" << labelRef + << "] vs all: " << confMatMeasurements->GetPrecisions()[indexLabelRef]); + otbAppLogINFO("Recall of class [" << labelRef + << "] vs all: " << confMatMeasurements->GetRecalls()[indexLabelRef]); + otbAppLogINFO("F-score of class [" << labelRef + << "] vs all: " << confMatMeasurements->GetFScores()[indexLabelRef]); otbAppLogINFO("\t"); - } + } otbAppLogINFO("Precision of the different classes: " << confMatMeasurements->GetPrecisions()); otbAppLogINFO("Recall of the different classes: " << confMatMeasurements->GetRecalls()); otbAppLogINFO("F-score of the different classes: " << confMatMeasurements->GetFScores()); @@ -403,27 +417,29 @@ public: otbAppLogINFO("Confusion matrix:\n" << confMat); } - void DoExecute() + void + DoExecute() { // Load the Tensorflow bundle - tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel); + 
tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, GetParameterStringList("model.tagsets")); - // Check if we have to restore variables from somewhere + // Check if we have to restore variables from somewhere else if (HasValue("model.restorefrom")) - { + { const std::string path = GetParameterAsString("model.restorefrom"); otbAppLogINFO("Restoring model from " + path); + + // Load SavedModel variables tf::RestoreModel(path, m_SavedModel); - } + } // Prepare inputs PrepareInputs(); // Setup training filter m_TrainModelFilter = TrainModelFilterType::New(); - m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); - m_TrainModelFilter->SetSession(m_SavedModel.session.get()); + m_TrainModelFilter->SetSavedModel(&m_SavedModel); m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors")); m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes")); m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); @@ -431,41 +447,38 @@ public: m_TrainModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming")); // Set inputs - for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++) - { + for (unsigned int i = 0; i < m_InputSourcesForTraining.size(); i++) + { m_TrainModelFilter->PushBackInputTensorBundle( - m_InputPlaceholdersForTraining[i], - m_InputPatchesSizeForTraining[i], - m_InputSourcesForTraining[i]); - } + m_InputPlaceholdersForTraining[i], m_InputPatchesSizeForTraining[i], m_InputSourcesForTraining[i]); + } // Setup the validation filter const bool do_validation = HasUserValue("validation.mode"); - if (GetParameterInt("validation.mode")==1) // class - { + if (GetParameterInt("validation.mode") == 1) // class + { otbAppLogINFO("Set validation mode to classification validation"); m_ValidateModelFilter = ValidateModelFilterType::New(); - m_ValidateModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); - 
m_ValidateModelFilter->SetSession(m_SavedModel.session.get()); + m_ValidateModelFilter->SetSavedModel(&m_SavedModel); m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders")); m_ValidateModelFilter->SetInputPlaceholders(m_InputPlaceholdersForValidation); m_ValidateModelFilter->SetInputReceptiveFields(m_InputPatchesSizeForValidation); m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames); m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize); - } - else if (GetParameterInt("validation.mode")==2) // rmse) - { + } + else if (GetParameterInt("validation.mode") == 2) // rmse) + { otbAppLogINFO("Set validation mode to classification RMSE evaluation"); otbAppLogFATAL("Not implemented yet !"); // XD // TODO - } + } // Epoch - for (int epoch = 1 ; epoch <= GetParameterInt("training.epochs") ; epoch++) - { + for (int epoch = 1; epoch <= GetParameterInt("training.epochs"); epoch++) + { // Train the model AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch)); m_TrainModelFilter->Update(); @@ -477,7 +490,7 @@ public: { // 1. 
Evaluate the metrics against the learning data - for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++) + for (unsigned int i = 0; i < m_InputSourcesForEvaluationAgainstLearningData.size(); i++) { m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]); } @@ -490,16 +503,17 @@ public: AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); m_ValidateModelFilter->Update(); - for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) + for (unsigned int i = 0; i < m_TargetTensorsNames.size(); i++) { otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); - PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); + PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), + m_ValidateModelFilter->GetMapOfClasses(i)); } // 2. Evaluate the metrics against the validation data // Here we just change the input sources and references - for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++) + for (unsigned int i = 0; i < m_InputSourcesForEvaluationAgainstValidationData.size(); i++) { m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]); } @@ -510,29 +524,28 @@ public: AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)"); m_ValidateModelFilter->Update(); - for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) + for (unsigned int i = 0; i < m_TargetTensorsNames.size(); i++) { otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); - PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); + PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), + m_ValidateModelFilter->GetMapOfClasses(i)); } } // Step is OK to perform validation - } // Do the validation against the validation data + } // Do the validation 
against the validation data - } // Next epoch + } // Next epoch // Check if we have to save variables to somewhere if (HasValue("model.saveto")) - { + { const std::string path = GetParameterAsString("model.saveto"); otbAppLogINFO("Saving model to " + path); tf::SaveModel(path, m_SavedModel); - } - + } } private: - - tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application ! + tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application ! // Filters TrainModelFilterType::Pointer m_TrainModelFilter; @@ -542,9 +555,9 @@ private: BundleList m_Bundles; // Patches size - SizeList m_InputPatchesSizeForTraining; - SizeList m_InputPatchesSizeForValidation; - SizeList m_TargetPatchesSize; + SizeList m_InputPatchesSizeForTraining; + SizeList m_InputPatchesSizeForValidation; + SizeList m_TargetPatchesSize; // Placeholders and Tensors names StringList m_InputPlaceholdersForTraining; @@ -560,7 +573,7 @@ private: }; // end of class -} // namespace wrapper +} // namespace Wrapper } // namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::TensorflowModelTrain ) +OTB_APPLICATION_EXPORT(otb::Wrapper::TensorflowModelTrain) diff --git a/app/otbTrainClassifierFromDeepFeatures.cxx b/app/otbTrainClassifierFromDeepFeatures.cxx index ae1e5d949bbb97d2c8587299d16d9b46f1a12e6a..cc3ec9edce5e54d484fa90751b632d50d203eb4b 100644 --- a/app/otbTrainClassifierFromDeepFeatures.cxx +++ b/app/otbTrainClassifierFromDeepFeatures.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -33,23 +34,23 @@ class TrainClassifierFromDeepFeatures : public CompositeApplication { public: /** Standard class typedefs. 
*/ - typedef TrainClassifierFromDeepFeatures Self; - typedef Application Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef TrainClassifierFromDeepFeatures Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); itkTypeMacro(TrainClassifierFromDeepFeatures, otb::Wrapper::CompositeApplication); private: - // // Add an input source, which includes: // -an input image list // -an input patchsize (dimensions of samples) // - void AddAnInputImage(int inputNumber = 0) + void + AddAnInputImage(int inputNumber = 0) { inputNumber++; @@ -60,75 +61,83 @@ private: // Populate group ShareParameter(ss_key_group.str(), "tfmodel." + ss_key_group.str(), ss_desc_group.str()); - } - void DoInit() + void + DoInit() { - SetName("TrainClassifierFromDeepFeatures"); - SetDescription("Train a classifier from deep net based features of an image and training vector data."); - - // Documentation - SetDocLongDescription("See TrainImagesClassifier application"); - SetDocLimitations("None"); - SetDocAuthors("Remi Cresson"); - SetDocSeeAlso(" "); - - AddDocTag(Tags::Learning); - - ClearApplications(); - - // Add applications - AddApplication("TrainImagesClassifier", "train", "Train images classifier"); - AddApplication("TensorflowModelServe", "tfmodel", "Serve the TF model"); - - // Model shared parameters - AddAnInputImage(); - for (int i = 1; i < tf::GetNumberOfSources() ; i++) - { - AddAnInputImage(i); - } - ShareParameter("model", "tfmodel.model", "Deep net inputs parameters", "Parameters of the deep net inputs: placeholder names, receptive fields, etc."); - ShareParameter("output", "tfmodel.output", "Deep net outputs parameters", "Parameters of the deep net outputs: tensors names, expression fields, etc."); - ShareParameter("optim", "tfmodel.optim", "Processing time optimization", "This group of 
parameters allows optimization of processing time"); - - // Train shared parameters - ShareParameter("ram", "train.ram", "Available RAM (Mb)", "Available RAM (Mb)"); - ShareParameter("vd", "train.io.vd", "Vector data for training", "Input vector data for training"); - ShareParameter("valid", "train.io.valid", "Vector data for validation", "Input vector data for validation"); - ShareParameter("out", "train.io.out", "Output classification model", "Output classification model"); - ShareParameter("confmatout", "train.io.confmatout", "Output confusion matrix", "Output confusion matrix of the classification model"); - - // Shared parameter groups - ShareParameter("sample", "train.sample", "Sampling parameters" , "Training and validation samples parameters" ); - ShareParameter("elev", "train.elev", "Elevation parameters", "Elevation parameters" ); - ShareParameter("classifier", "train.classifier", "Classifier parameters", "Classifier parameters" ); - ShareParameter("rand", "train.rand", "User defined random seed", "User defined random seed" ); - + SetName("TrainClassifierFromDeepFeatures"); + SetDescription("Train a classifier from deep net based features of an image and training vector data."); + + // Documentation + SetDocLongDescription("See TrainImagesClassifier application"); + SetDocLimitations("None"); + SetDocAuthors("Remi Cresson"); + SetDocSeeAlso(" "); + + AddDocTag(Tags::Learning); + + ClearApplications(); + + // Add applications + AddApplication("TrainImagesClassifier", "train", "Train images classifier"); + AddApplication("TensorflowModelServe", "tfmodel", "Serve the TF model"); + + // Model shared parameters + AddAnInputImage(); + for (int i = 1; i < tf::GetNumberOfSources(); i++) + { + AddAnInputImage(i); + } + ShareParameter("model", + "tfmodel.model", + "Deep net inputs parameters", + "Parameters of the deep net inputs: placeholder names, receptive fields, etc."); + ShareParameter("output", + "tfmodel.output", + "Deep net outputs parameters", + 
"Parameters of the deep net outputs: tensors names, expression fields, etc."); + ShareParameter("optim", + "tfmodel.optim", + "Processing time optimization", + "This group of parameters allows optimization of processing time"); + + // Train shared parameters + ShareParameter("ram", "train.ram", "Available RAM (Mb)", "Available RAM (Mb)"); + ShareParameter("vd", "train.io.vd", "Vector data for training", "Input vector data for training"); + ShareParameter("valid", "train.io.valid", "Vector data for validation", "Input vector data for validation"); + ShareParameter("out", "train.io.out", "Output classification model", "Output classification model"); + ShareParameter("confmatout", + "train.io.confmatout", + "Output confusion matrix", + "Output confusion matrix of the classification model"); + + // Shared parameter groups + ShareParameter("sample", "train.sample", "Sampling parameters", "Training and validation samples parameters"); + ShareParameter("elev", "train.elev", "Elevation parameters", "Elevation parameters"); + ShareParameter("classifier", "train.classifier", "Classifier parameters", "Classifier parameters"); + ShareParameter("rand", "train.rand", "User defined random seed", "User defined random seed"); } - void DoUpdateParameters() + void + DoUpdateParameters() { UpdateInternalParameters("train"); } - void DoExecute() + void + DoExecute() { ExecuteInternal("tfmodel"); - GetInternalApplication("train")->AddImageToParameterInputImageList("io.il", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); + GetInternalApplication("train")->AddImageToParameterInputImageList( + "io.il", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); UpdateInternalParameters("train"); ExecuteInternal("train"); - } // DOExecute() - - void AfterExecuteAndWriteOutputs() - { - // Nothing to do } - }; } // namespace Wrapper } // namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::TrainClassifierFromDeepFeatures ) 
+OTB_APPLICATION_EXPORT(otb::Wrapper::TrainClassifierFromDeepFeatures) diff --git a/doc/APPLICATIONS.md b/doc/APPLICATIONS.md index 69282cd543576f66a790f9619f6f75b31b4e1514..25d739fb863e3939d8bb4cbdbd0fe4050766f10a 100644 --- a/doc/APPLICATIONS.md +++ b/doc/APPLICATIONS.md @@ -59,7 +59,7 @@ When using a model in OTBTF, the important thing is to know the following parame  -The **scale factor** descibes the physical change of spacing of the outputs, typically introduced in the model by non unitary strides in pooling or convolution operators. +The **scale factor** describes the physical change of spacing of the outputs, typically introduced in the model by non unitary strides in pooling or convolution operators. For each output, it is expressed relatively to one single input of the model called the _reference input source_. Additionally, the names of the _target nodes_ must be known (e.g. "optimizer"). Also, the names of _user placeholders_, typically scalars placeholders that are used to control some parameters of the model, must be know (e.g. "dropout_rate"). @@ -356,7 +356,7 @@ But here, we will just perform some fine tuning of our model. The **SavedModel** is located in the `outmodel` directory. Our model is quite basic: it has two input placeholders, **x1** and **y1** respectively for input patches (with size 16x16) and input reference labels (with size 1x1). We named **prediction** the tensor that predict the labels and the optimizer that perform the stochastic gradient descent is an operator named **optimizer**. -We perform the fine tuning and we export the new model variables directly in the _outmodel/variables_ folder, overwritting the existing variables of the model. +We perform the fine tuning and we export the new model variables directly in the _outmodel/variables_ folder, overwriting the existing variables of the model. We use the **TensorflowModelTrain** application to perform the training of this existing model. 
``` otbcli_TensorflowModelTrain -model.dir /path/to/oursavedmodel -training.targetnodesnames optimizer -training.source1.il samp_patches.tif -training.source1.patchsizex 16 -training.source1.patchsizey 16 -training.source1.placeholder x1 -training.source2.il samp_labels.tif -training.source2.patchsizex 1 -training.source2.patchsizey 1 -training.source2.placeholder y1 -model.saveto /path/to/oursavedmodel/variables/variables diff --git a/doc/DOCKERUSE.md b/doc/DOCKERUSE.md index b510877751d622bbf11b0410c336b2d41dffc654..0f1086015a59c46efd5a1ca3f93a61dd962fa75f 100644 --- a/doc/DOCKERUSE.md +++ b/doc/DOCKERUSE.md @@ -4,28 +4,33 @@ Here is the list of OTBTF docker images hosted on [dockerhub](https://hub.docker.com/u/mdl4eo). -| Name | Os | TF | OTB | Description | -| ----------------------------- | ------------- | ------ | ----- | ---------------------- | -| **mdl4eo/otbtf1.6:cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | -| **mdl4eo/otbtf1.7:cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | -| **mdl4eo/otbtf1.7:gpu** | Ubuntu Xenial | r1.14 | 7.0.0 | GPU | -| **mdl4eo/otbtf2.0:cpu** | Ubuntu Xenial | r2.1 | 7.1.0 | CPU, no optimization | -| **mdl4eo/otbtf2.0:gpu** | Ubuntu Xenial | r2.1 | 7.1.0 | GPU | -| **mdl4eo/otbtf2.4:cpu-basic** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, no optimization | -| **mdl4eo/otbtf2.4:cpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, few optimizations | -| **mdl4eo/otbtf2.4:cpu-mkl** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, Intel MKL, AVX512 | -| **mdl4eo/otbtf2.4:gpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | GPU | +| Name | Os | TF | OTB | Description | Dev files | Compute capability | +| --------------------------------- | ------------- | ------ | ----- | ---------------------- | --------- | ------------------ | +| **mdl4eo/otbtf1.6:cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | yes | 5.2,6.1,7.0 | +| **mdl4eo/otbtf1.7:cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | yes | 5.2,6.1,7.0 | +| 
**mdl4eo/otbtf1.7:gpu** | Ubuntu Xenial | r1.14 | 7.0.0 | GPU | yes | 5.2,6.1,7.0 | +| **mdl4eo/otbtf2.0:cpu** | Ubuntu Xenial | r2.1 | 7.1.0 | CPU, no optimization | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf2.0:gpu** | Ubuntu Xenial | r2.1 | 7.1.0 | GPU | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf2.4:cpu-basic** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, no optimization | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf2.4:cpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, few optimizations | no | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf2.4:cpu-mkl** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, Intel MKL, AVX512 | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf2.4:gpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | GPU | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf2.5:cpu-basic** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf2.5:cpu-basic-dev** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf2.5:cpu** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, few optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf2.5:gpu** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf2.5:gpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| - `cpu` tagged docker images are compiled without optimization. -- `gpu` tagged docker images are suited for **NVIDIA GPUs**. They use CUDA/CUDNN support and are built with compute capabilities 5.2, 6.1, 7.0, 7.5. +- `gpu` tagged docker images are suited for **NVIDIA GPUs**. They use CUDA/CUDNN support. - `cpu-mkl` tagged docker image is experimental, it is optimized for Intel CPUs with AVX512 flags. -You can also find plenty of interesting OTBTF flavored images at [LaTelescop gitlab registry](https://gitlab.com/latelescop/docker/otbtf/container_registry/). +You can also find more interesting OTBTF flavored images at [LaTelescop gitlab registry](https://gitlab.com/latelescop/docker/otbtf/container_registry/). 
### Development ready images -Until r2.4, all images are development-ready. For instance, you can recompile the whole OTB from `/work/otb/build/OTB/build`. -Since r2.4, only `gpu` tagged image is development-ready, and you can recompile OTB from `/src/otb/build/OTB/build`. +Until r2.4, all images are development-ready, and the sources are located in `/work/`. +Since r2.4, development-ready images have the source in `/src/`. ### Build your own images @@ -40,7 +45,7 @@ For instance, suppose you have some data in `/mnt/my_device/` that you want to u The following command shows you how to access the folder from the docker image. ```bash -docker run -v /mnt/my_device/:/data/ -ti mdl4eo/otbtf2.4:cpu bash -c "ls /data" +docker run -v /mnt/my_device/:/data/ -ti mdl4eo/otbtf2.5:cpu bash -c "ls /data" ``` Beware of ownership issues! see the last section of this doc. @@ -53,13 +58,13 @@ You can then use the OTBTF `gpu` tagged docker images with the **NVIDIA runtime* With Docker version earlier than 19.03 : ```bash -docker run --runtime=nvidia -ti mdl4eo/otbtf2.4:gpu bash +docker run --runtime=nvidia -ti mdl4eo/otbtf2.5:gpu bash ``` With Docker version including and after 19.03 : ```bash -docker run --gpus all -ti mdl4eo/otbtf2.4:gpu bash +docker run --gpus all -ti mdl4eo/otbtf2.5:gpu bash ``` You can find some details on the **GPU docker image** and some **docker tips and tricks** on [this blog](https://mdl4eo.irstea.fr/2019/10/15/otbtf-docker-image-with-gpu/). @@ -72,7 +77,7 @@ Be careful though, these infos might be a bit outdated... 1. Install [WSL2](https://docs.microsoft.com/en-us/windows/wsl/install-win10#manual-installation-steps) (Windows Subsystem for Linux) 2. Install [docker desktop](https://www.docker.com/products/docker-desktop) 3. Start **docker desktop** and **enable WSL2** from *Settings* > *General* then tick the box *Use the WSL2 based engine* -3. 
Open a **cmd.exe** or **PowerShell** terminal, and type `docker create --name otbtf-cpu --interactive --tty mdl4eo/otbtf2.4:cpu` +3. Open a **cmd.exe** or **PowerShell** terminal, and type `docker create --name otbtf-cpu --interactive --tty mdl4eo/otbtf2.5:cpu` 4. Open **docker desktop**, and check that the docker is running in the **Container/Apps** menu  5. From **docker desktop**, click on the icon highlighted as shown below, and use the bash terminal that should pop up! @@ -102,7 +107,7 @@ This section is largely inspired from the [moringa docker help](https://gitlab.i ## Useful diagnostic commands -Here are some usefull commands. +Here are some useful commands. ```bash docker info # System info @@ -121,12 +126,12 @@ sudo systemctl {status,enable,disable,start,stop} docker Run a simple command in a one-shot container: ```bash -docker run mdl4eo/otbtf2.4:cpu otbcli_PatchesExtraction +docker run mdl4eo/otbtf2.5:cpu otbcli_PatchesExtraction ``` You can also use the image in interactive mode with bash: ```bash -docker run -ti mdl4eo/otbtf2.4:cpu bash +docker run -ti mdl4eo/otbtf2.5:cpu bash ``` ### Persistent container @@ -136,7 +141,7 @@ Beware of ownership issues, see the last section of this doc. 
```bash docker create --interactive --tty --volume /home/$USER:/home/otbuser/ \ - --name otbtf mdl4eo/otbtf2.4:cpu /bin/bash + --name otbtf mdl4eo/otbtf2.5:cpu /bin/bash ``` ### Interactive session @@ -200,7 +205,7 @@ Create a named container (here with your HOME as volume), Docker will automatica ```bash docker create --interactive --tty --volume /home/$USER:/home/otbuser \ - --name otbtf mdl4eo/otbtf2.4:cpu /bin/bash + --name otbtf mdl4eo/otbtf2.5:cpu /bin/bash ``` Start a background container process: diff --git a/include/otbTensorflowCommon.cxx b/include/otbTensorflowCommon.cxx index b93717ed3aa5806f364ebcaa1d73edfa25c45f6b..b7a27c60c5ef49fbee42556ace70b54751f682f9 100644 --- a/include/otbTensorflowCommon.cxx +++ b/include/otbTensorflowCommon.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,8 +11,10 @@ =========================================================================*/ #include "otbTensorflowCommon.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // // Environment variable for the number of sources in "Multisource" applications @@ -22,21 +24,21 @@ const std::string ENV_VAR_NAME_NSOURCES = "OTB_TF_NSOURCES"; // // Get the environment variable as int // -int GetEnvironmentVariableAsInt(const std::string & variableName) +int +GetEnvironmentVariableAsInt(const std::string & variableName) { - int ret = -1; - char const* tmp = getenv( variableName.c_str() ); - if ( tmp != NULL ) + int ret = -1; + char const * tmp = getenv(variableName.c_str()); + if (tmp != NULL) { - std::string s( tmp ); + std::string s(tmp); try { ret = std::stoi(s); } - catch(...) + catch (...) { - itkGenericExceptionMacro("Error parsing variable " - << variableName << " as integer. 
Value is " << s); + itkGenericExceptionMacro("Error parsing variable " << variableName << " as integer. Value is " << s); } } @@ -47,7 +49,8 @@ int GetEnvironmentVariableAsInt(const std::string & variableName) // This function returns the numeric content of the ENV_VAR_NAME_NSOURCES // environment variable // -int GetNumberOfSources() +int +GetNumberOfSources() { int ret = GetEnvironmentVariableAsInt(ENV_VAR_NAME_NSOURCES); if (ret != -1) @@ -60,15 +63,18 @@ int GetNumberOfSources() // // This function copy a patch from an input image to an output image // -template<class TImage> -void CopyPatch(typename TImage::Pointer inputImg, typename TImage::IndexType & inputPatchIndex, - typename TImage::Pointer outputImg, typename TImage::IndexType & outputPatchIndex, - typename TImage::SizeType patchSize) +template <class TImage> +void +CopyPatch(typename TImage::Pointer inputImg, + typename TImage::IndexType & inputPatchIndex, + typename TImage::Pointer outputImg, + typename TImage::IndexType & outputPatchIndex, + typename TImage::SizeType patchSize) { - typename TImage::RegionType inputPatchRegion(inputPatchIndex, patchSize); - typename TImage::RegionType outputPatchRegion(outputPatchIndex, patchSize); - typename itk::ImageRegionConstIterator<TImage> inIt (inputImg, inputPatchRegion); - typename itk::ImageRegionIterator<TImage> outIt (outputImg, outputPatchRegion); + typename TImage::RegionType inputPatchRegion(inputPatchIndex, patchSize); + typename TImage::RegionType outputPatchRegion(outputPatchIndex, patchSize); + typename itk::ImageRegionConstIterator<TImage> inIt(inputImg, inputPatchRegion); + typename itk::ImageRegionIterator<TImage> outIt(outputImg, outputPatchRegion); for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++outIt) { outIt.Set(inIt.Get()); @@ -78,9 +84,9 @@ void CopyPatch(typename TImage::Pointer inputImg, typename TImage::IndexType & i // // Get image infos // -template<class TImage> -void GetImageInfo(typename TImage::Pointer image, 
- unsigned int & sizex, unsigned int & sizey, unsigned int & nBands) +template <class TImage> +void +GetImageInfo(typename TImage::Pointer image, unsigned int & sizex, unsigned int & sizey, unsigned int & nBands) { nBands = image->GetNumberOfComponentsPerPixel(); sizex = image->GetLargestPossibleRegion().GetSize(0); @@ -90,8 +96,9 @@ void GetImageInfo(typename TImage::Pointer image, // // Propagate the requested region in the image // -template<class TImage> -void PropagateRequestedRegion(typename TImage::Pointer image, typename TImage::RegionType & region) +template <class TImage> +void +PropagateRequestedRegion(typename TImage::Pointer image, typename TImage::RegionType & region) { image->SetRequestedRegion(region); image->PropagateRequestedRegion(); @@ -101,13 +108,16 @@ void PropagateRequestedRegion(typename TImage::Pointer image, typename TImage::R // // Sample an input image at the specified location // -template<class TImage> -bool SampleImage(const typename TImage::Pointer inPtr, typename TImage::Pointer outPtr, - typename TImage::PointType point, unsigned int elemIdx, - typename TImage::SizeType patchSize) +template <class TImage> +bool +SampleImage(const typename TImage::Pointer inPtr, + typename TImage::Pointer outPtr, + typename TImage::PointType point, + unsigned int elemIdx, + typename TImage::SizeType patchSize) { typename TImage::IndexType index, outIndex; - bool canTransform = inPtr->TransformPhysicalPointToIndex(point, index); + bool canTransform = inPtr->TransformPhysicalPointToIndex(point, index); if (canTransform) { outIndex[0] = 0; @@ -128,7 +138,6 @@ bool SampleImage(const typename TImage::Pointer inPtr, typename TImage::Pointer } } return false; - } } // end namespace tf diff --git a/include/otbTensorflowCommon.h b/include/otbTensorflowCommon.h index e8db4b8b361e9a1277f96ce50aba4027ffc5d251..a012173c66ec9c4a10ae0f5f7df9908f01dd4833 100644 --- a/include/otbTensorflowCommon.h +++ b/include/otbTensorflowCommon.h @@ -1,7 +1,7 @@ 
/*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,39 +18,53 @@ #include <string> #include <algorithm> #include <functional> +#include "itkMacro.h" +#include "itkImageRegionConstIterator.h" +#include "itkImageRegionIterator.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // Environment variable for the number of sources in "Multisource" applications extern const std::string ENV_VAR_NAME_NSOURCES; // Get the environment variable as int -int GetEnvironmentVariableAsInt(const std::string & variableName); +int +GetEnvironmentVariableAsInt(const std::string & variableName); // Get the value (as int) of the environment variable ENV_VAR_NAME_NSOURCES -int GetNumberOfSources(); +int +GetNumberOfSources(); // This function copy a patch from an input image to an output image -template<class TImage> -void CopyPatch(typename TImage::Pointer inputImg, typename TImage::IndexType & inputPatchIndex, - typename TImage::Pointer outputImg, typename TImage::IndexType & outputPatchIndex, - typename TImage::SizeType patchSize); +template <class TImage> +void +CopyPatch(typename TImage::Pointer inputImg, + typename TImage::IndexType & inputPatchIndex, + typename TImage::Pointer outputImg, + typename TImage::IndexType & outputPatchIndex, + typename TImage::SizeType patchSize); // Get image infos -template<class TImage> -void GetImageInfo(typename TImage::Pointer image, - unsigned int & sizex, unsigned int & sizey, unsigned int & nBands); +template <class TImage> +void +GetImageInfo(typename TImage::Pointer image, unsigned int & sizex, unsigned int & sizey, unsigned int & nBands); // Propagate the requested region in the image -template<class TImage> -void PropagateRequestedRegion(typename TImage::Pointer image, 
typename TImage::RegionType & region); +template <class TImage> +void +PropagateRequestedRegion(typename TImage::Pointer image, typename TImage::RegionType & region); // Sample an input image at the specified location -template<class TImage> -bool SampleImage(const typename TImage::Pointer inPtr, typename TImage::Pointer outPtr, - typename TImage::PointType point, unsigned int elemIdx, - typename TImage::SizeType patchSize); +template <class TImage> +bool +SampleImage(const typename TImage::Pointer inPtr, + typename TImage::Pointer outPtr, + typename TImage::PointType point, + unsigned int elemIdx, + typename TImage::SizeType patchSize); } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowCopyUtils.cxx b/include/otbTensorflowCopyUtils.cxx index 116aafb0b5fb1fe5b00f830291c0218fdb3f6bdf..b2a6e70e3a44ba6141e96f38b2618679f9d007d4 100644 --- a/include/otbTensorflowCopyUtils.cxx +++ b/include/otbTensorflowCopyUtils.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,27 +11,31 @@ =========================================================================*/ #include "otbTensorflowCopyUtils.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // // Display a TensorShape // -std::string PrintTensorShape(const tensorflow::TensorShape & shp) +std::string +PrintTensorShape(const tensorflow::TensorShape & shp) { std::stringstream s; - unsigned int nDims = shp.dims(); + unsigned int nDims = shp.dims(); s << "{" << shp.dim_size(0); - for (unsigned int d = 1 ; d < nDims ; d++) + for (unsigned int d = 1; d < nDims; d++) s << ", " << shp.dim_size(d); - s << "}" ; + s << "}"; return s.str(); } // // Display infos about a tensor // -std::string PrintTensorInfos(const 
tensorflow::Tensor & tensor) +std::string +PrintTensorInfos(const tensorflow::Tensor & tensor) { std::stringstream s; s << "Tensor "; @@ -39,17 +43,19 @@ std::string PrintTensorInfos(const tensorflow::Tensor & tensor) s << "shape is " << PrintTensorShape(tensor.shape()); // Data type s << " data type is " << tensor.dtype(); + s << " (" << tf::GetDataTypeAsString(tensor.dtype()) << ")"; return s.str(); } // // Create a tensor with the good datatype // -template<class TImage> -tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape) +template <class TImage> +tensorflow::Tensor +CreateTensor(tensorflow::TensorShape & shape) { tensorflow::DataType ts_dt = GetTensorflowDataType<typename TImage::InternalPixelType>(); - tensorflow::Tensor out_tensor(ts_dt, shape); + tensorflow::Tensor out_tensor(ts_dt, shape); return out_tensor; } @@ -58,32 +64,35 @@ tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape) // Populate a tensor with the buffered region of a vector image using std::copy // Warning: tensor datatype must be consistent with the image value type // -template<class TImage> -void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor) +template <class TImage> +void +PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor) { - size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() * - bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels(); - std::copy_n(bufferedimagePtr->GetBufferPointer(), - n_elem, - out_tensor.flat<typename TImage::InternalPixelType>().data()); + size_t n_elem = + bufferedimagePtr->GetNumberOfComponentsPerPixel() * bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels(); + std::copy_n( + bufferedimagePtr->GetBufferPointer(), n_elem, out_tensor.flat<typename TImage::InternalPixelType>().data()); } // // Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, 
sz_bands}) // -template<class TImage, class TValueType=typename TImage::InternalPixelType> -void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, - tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension +template <class TImage, class TValueType = typename TImage::InternalPixelType> +void +RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx) // element position along the 1st dimension { typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region); - unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); - auto tMap = tensor.tensor<TValueType, 4>(); + unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); + auto tMap = tensor.tensor<TValueType, 4>(); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { const int y = inIt.GetIndex()[1] - region.GetIndex()[1]; const int x = inIt.GetIndex()[0] - region.GetIndex()[0]; - for (unsigned int band = 0 ; band < nBands ; band++) + for (unsigned int band = 0; band < nBands; band++) tMap(elemIdx, y, x, band) = inIt.Get()[band]; } } @@ -92,9 +101,12 @@ void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const ty // Type-agnostic version of the 'RecopyImageRegionToTensor' function // TODO: add some numeric types // -template<class TImage> -void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, - tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension +template <class TImage> +void +RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx) // element position along the 1st dimension { tensorflow::DataType dt = tensor.dtype(); if (dt == 
tensorflow::DT_FLOAT) @@ -110,21 +122,25 @@ void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, else if (dt == tensorflow::DT_INT32) RecopyImageRegionToTensor<TImage, int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT16) - RecopyImageRegionToTensor<TImage, unsigned short int> (inputPtr, region, tensor, elemIdx); + RecopyImageRegionToTensor<TImage, unsigned short int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT16) RecopyImageRegionToTensor<TImage, short int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT8) - RecopyImageRegionToTensor<TImage, unsigned char> (inputPtr, region, tensor, elemIdx); + RecopyImageRegionToTensor<TImage, unsigned char>(inputPtr, region, tensor, elemIdx); else - itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); + itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !"); } // // Sample a centered patch (from index) // -template<class TImage> -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize, - tensorflow::Tensor & tensor, unsigned int elemIdx) +template <class TImage> +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::IndexType & centerIndex, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx) { typename TImage::IndexType regionStart; regionStart[0] = centerIndex[0] - patchSize[0] / 2; @@ -136,9 +152,13 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename // // Sample a centered patch (from coordinates) // -template<class TImage> -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, - tensorflow::Tensor & tensor, unsigned int elemIdx) +template <class 
TImage> +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::PointType & centerCoord, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx) { // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands} // Get the index of the center @@ -147,41 +167,48 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename SampleCenteredPatch<TImage>(inputPtr, centerIndex, patchSize, tensor, elemIdx); } -// Return the number of channels that the output tensor will occupy in the output image // +// Return the number of channels from the TensorShapeProto // shape {n} --> 1 (e.g. a label) -// shape {n, c} --> c (e.g. a vector) -// shape {x, y, c} --> c (e.g. a patch) -// shape {n, x, y, c} --> c (e.g. some patches) +// shape {n, c} --> c (e.g. a pixel) +// shape {n, x, y} --> 1 (e.g. a mono-channel patch) +// shape {n, x, y, c} --> c (e.g. a multi-channel patch) // -tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor) +tensorflow::int64 +GetNumberOfChannelsFromShapeProto(const tensorflow::TensorShapeProto & proto) { - const tensorflow::TensorShape shape = tensor.shape(); - const int nDims = shape.dims(); + const int nDims = proto.dim_size(); if (nDims == 1) + // e.g. a batch prediction, as flat tensor + return 1; + if (nDims == 3) + // typically when the last dimension in squeezed following a + // computation that does not keep dimensions (e.g. reduce_sum, etc.) return 1; - return shape.dim_size(nDims - 1); + // any other dimension: we assume that the last dimension represent the + // number of channels in the output image. + return proto.dim(nDims - 1).size(); } // // Copy a tensor into the image region -// TODO: Enable to change mapping from source tensor to image to make it more generic -// -// Right now, only the following output tensor shapes can be processed: -// shape {n} --> 1 (e.g. a label) -// shape {n, c} --> c (e.g. 
a vector) -// shape {x, y, c} --> c (e.g. a multichannel image) // -template<class TImage, class TValueType> -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, - typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset) +template <class TImage, class TValueType> +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & outputRegion, + int & channelOffset) { // Flatten the tensor auto tFlat = tensor.flat<TValueType>(); - // Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)') - const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor); + // Get the number of component of the output image + tensorflow::TensorShapeProto proto; + tensor.shape().AsProto(&proto); + const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsFromShapeProto(proto); // Number of columns (size x of the buffer) const tensorflow::int64 nCols = bufferRegion.GetSize(0); @@ -191,15 +218,16 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C; if (nElmI != nElmT) { - itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT << - " but image outputRegion has " << nElmI << - " values to fill.\nBuffer region:\n" << bufferRegion << - "\nNumber of components: " << outputDimSize_C << - "\nTensor shape:\n " << PrintTensorShape(tensor.shape()) << - "\nPlease check the input(s) field of view (FOV), " << - "the output field of expression (FOE), and the " << - "output spacing scale if you run the model in fully " << - "convolutional mode (how many strides in your model?)"); + itkGenericExceptionMacro("Number of elements in the tensor is " + << nElmT << " but image 
outputRegion has " << nElmI << " values to fill.\n" + << "Buffer region is: \n" + << bufferRegion << "\n" + << "Number of components in the output image: " << outputDimSize_C << "\n" + << "Tensor shape: " << PrintTensorShape(tensor.shape()) << "\n" + << "Please check the input(s) field of view (FOV), " + << "the output field of expression (FOE), and the " + << "output spacing scale if you run the model in fully " + << "convolutional mode (how many strides in your model?)"); } // Iterate over the image @@ -212,145 +240,217 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T // TODO: it could be useful to change the tensor-->image mapping here. // e.g use a lambda for "pos" calculation const int pos = outputDimSize_C * (y * nCols + x); - for (unsigned int c = 0 ; c < outputDimSize_C ; c++) - outIt.Get()[channelOffset + c] = tFlat( pos + c); + for (unsigned int c = 0; c < outputDimSize_C; c++) + outIt.Get()[channelOffset + c] = tFlat(pos + c); } // Update the offset channelOffset += outputDimSize_C; - } // // Type-agnostic version of the 'CopyTensorToImageRegion' function -// TODO: add some numeric types // -template<class TImage> -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, - typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset) +template <class TImage> +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & region, + int & channelOffset) { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) - CopyTensorToImageRegion<TImage, float> (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion<TImage, float>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_DOUBLE) - CopyTensorToImageRegion<TImage, double> 
(tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion<TImage, double>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT64) + CopyTensorToImageRegion<TImage, unsigned long long int>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT64) CopyTensorToImageRegion<TImage, long long int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT32) + CopyTensorToImageRegion<TImage, unsigned int>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT32) - CopyTensorToImageRegion<TImage, int> (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion<TImage, int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT16) + CopyTensorToImageRegion<TImage, unsigned short int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_INT16) + CopyTensorToImageRegion<TImage, short int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT8) + CopyTensorToImageRegion<TImage, unsigned char>(tensor, bufferRegion, outputPtr, region, channelOffset); else - itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); - + itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !"); } // // Compare two string lowercase // -bool iequals(const std::string& a, const std::string& b) +bool +iequals(const std::string & a, const std::string & b) { - return std::equal(a.begin(), a.end(), - b.begin(), b.end(), - [](char cha, char chb) { - return tolower(cha) == tolower(chb); - }); + return std::equal( + a.begin(), a.end(), b.begin(), b.end(), [](char cha, char chb) { return tolower(cha) == tolower(chb); }); } -// Convert an expression into a dict -// +// Convert a value into a tensor // Following types are supported: // -bool // -int // 
-float +// -vector of float +// +// e.g. "true", "0.2", "14", "(1.2, 4.2, 4)" // -// e.g. is_training=true, droptout=0.2, nfeat=14 -std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression) +// TODO: we could add some other types (e.g. string) +tensorflow::Tensor +ValueToTensor(std::string value) { - std::pair<std::string, tensorflow::Tensor> dict; + std::vector<std::string> values; - std::size_t found = expression.find("="); - if (found != std::string::npos) - { - // Find name and value - std::string name = expression.substr(0, found); - std::string value = expression.substr(found+1); + // Check if value is a vector or a scalar + const bool has_left = (value[0] == '('); + const bool has_right = value[value.size() - 1] == ')'; - dict.first = name; + // Check consistency + bool is_vec = false; + if (has_left || has_right) + { + is_vec = true; + if (!has_left || !has_right) + itkGenericExceptionMacro("Error parsing vector expression (missing parentheses ?)" << value); + } + + // Scalar --> Vector for generic processing + if (!is_vec) + { + values.push_back(value); + } + else + { + // Remove "(" and ")" chars + std::string trimmed_value = value.substr(1, value.size() - 2); + + // Split string into vector using "," delimiter + std::regex rgx("\\s*,\\s*"); + std::sregex_token_iterator iter{ trimmed_value.begin(), trimmed_value.end(), rgx, -1 }; + std::sregex_token_iterator end; + values = std::vector<std::string>({ iter, end }); + } + + // Find type + bool has_dot = false; + bool is_digit = true; + for (auto & val : values) + { + has_dot = has_dot || val.find(".") != std::string::npos; + is_digit = is_digit && val.find_first_not_of("-0123456789.") == std::string::npos; + } + + // Create tensor + tensorflow::TensorShape shape({ values.size() }); + tensorflow::Tensor out(tensorflow::DT_BOOL, shape); + if (is_digit) + { + if (has_dot) + out = tensorflow::Tensor(tensorflow::DT_FLOAT, shape); + else + out = tensorflow::Tensor(tensorflow::DT_INT32, 
shape); + } + + // Fill tensor + unsigned int idx = 0; + for (auto & val : values) + { - // Find type - std::size_t found_dot = value.find(".") != std::string::npos; - std::size_t is_digit = value.find_first_not_of("0123456789.") == std::string::npos; - if (is_digit) + if (is_digit) + { + if (has_dot) { - if (found_dot) + // FLOAT + try { - // FLOAT - try - { - float val = std::stof(value); - tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape()); - out.scalar<float>()() = val; - dict.second = out; - - } - catch(...) - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as float"); - } - + out.flat<float>()(idx) = std::stof(val); } - else + catch (...) { - // INT - try - { - int val = std::stoi(value); - tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape()); - out.scalar<int>()() = val; - dict.second = out; - - } - catch(...) - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as int"); - } - + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as float"); } } else { - // BOOL - bool val = true; - if (iequals(value, "true")) - { - val = true; - } - else if (iequals(value, "false")) + // INT + try { - val = false; + out.flat<int>()(idx) = std::stoi(val); } - else + catch (...) 
{ - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as bool"); + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as int"); } - tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape()); - out.scalar<bool>()() = val; - dict.second = out; } - } else { - itkGenericExceptionMacro("The following expression is not valid: " - << "\n\t" << expression - << ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true."); + // BOOL + bool ret = true; + if (iequals(val, "true")) + { + ret = true; + } + else if (iequals(val, "false")) + { + ret = false; + } + else + { + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as bool"); + } + out.flat<bool>()(idx) = ret; } + idx++; + } + otbLogMacro(Debug, << "Returning tensor: " << out.DebugString()); + + return out; +} + +// Convert an expression into a dict +// +// Following types are supported: +// -bool +// -int +// -float +// -vector of float +// +// e.g. is_training=true, droptout=0.2, nfeat=14, x=(1.2, 4.2, 4) +std::pair<std::string, tensorflow::Tensor> +ExpressionToTensor(std::string expression) +{ + std::pair<std::string, tensorflow::Tensor> dict; - return dict; + std::size_t found = expression.find("="); + if (found != std::string::npos) + { + // Find name and value + std::string name = expression.substr(0, found); + std::string value = expression.substr(found + 1); + + dict.first = name; + + // Transform value into tensorflow::Tensor + dict.second = ValueToTensor(value); + } + else + { + itkGenericExceptionMacro("The following expression is not valid: " + << "\n\t" << expression << ".\nExpression must be in one of the following form:" + << "\n- int32_value=1 \n- float_value=1.0 \n- bool_value=true." 
+ << "\n- float_vec=(1.0, 5.253, 2)"); + } + + return dict; } } // end namespace tf diff --git a/include/otbTensorflowCopyUtils.h b/include/otbTensorflowCopyUtils.h index 47ad6cf2366137ac9719158069ac0d91ffef5813..59e1a7443ff78511b42d6d67c74023cd49864235 100644 --- a/include/otbTensorflowCopyUtils.h +++ b/include/otbTensorflowCopyUtils.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -15,67 +15,113 @@ // ITK exception #include "itkMacro.h" +// OTB log +#include "otbMacro.h" + // ITK image iterators #include "itkImageRegionIterator.h" #include "itkImageRegionConstIterator.h" // tensorflow::tensor #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // tensorflow::datatype <--> ImageType::InternalPixelType #include "otbTensorflowDataTypeBridge.h" // STD #include <string> +#include <regex> -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // Generate a string with TensorShape infos -std::string PrintTensorShape(const tensorflow::TensorShape & shp); +std::string +PrintTensorShape(const tensorflow::TensorShape & shp); // Generate a string with tensor infos -std::string PrintTensorInfos(const tensorflow::Tensor & tensor); +std::string +PrintTensorInfos(const tensorflow::Tensor & tensor); // Create a tensor with the good datatype -template<class TImage> -tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape); +template <class TImage> +tensorflow::Tensor +CreateTensor(tensorflow::TensorShape & shape); // Populate a tensor with the buffered region of a vector image -template<class TImage> -void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, 
tensorflow::Tensor & out_tensor); +template <class TImage> +void +PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor); // Populate the buffered region of a vector image with a given tensor's values -template<class TImage> -void TensorToImageBuffer(const tensorflow::Tensor & tensor, typename TImage::Pointer & image); +template <class TImage> +void +TensorToImageBuffer(const tensorflow::Tensor & tensor, typename TImage::Pointer & image); // Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands}) -template<class TImage, class TValueType=typename TImage::InternalPixelType> -void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, tensorflow::Tensor & tensor, unsigned int elemIdx); +template <class TImage, class TValueType = typename TImage::InternalPixelType> +void +RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx); // Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor (TValueType-agnostic function) -template<class TImage> -void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, tensorflow::Tensor & tensor, unsigned int elemIdx); +template <class TImage> +void +RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx); // Sample a centered patch -template<class TImage> -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx); -template<class TImage> -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename 
TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx); - -// Return the number of channels that the output tensor will occupy in the output image -tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor); +template <class TImage> +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::IndexType & centerIndex, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx); +template <class TImage> +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::PointType & centerCoord, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx); + +// Return the number of channels from the TensorflowShapeProto +tensorflow::int64 +GetNumberOfChannelsFromShapeProto(const tensorflow::TensorShapeProto & proto); // Copy a tensor into the image region -template<class TImage, class TValueType> -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset); +template <class TImage, class TValueType> +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & region, + int & channelOffset); // Copy a tensor into the image region (TValueType-agnostic version) -template<class TImage> -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset); +template <class TImage> +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & outputRegion, + int & 
channelOffset); + +// Convert a value into a tensor +tensorflow::Tensor +ValueToTensor(std::string value); // Convert an expression into a dict -std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression); +std::pair<std::string, tensorflow::Tensor> +ExpressionToTensor(std::string expression); } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowDataTypeBridge.cxx b/include/otbTensorflowDataTypeBridge.cxx index 5a421c92ae0828a2cd39d91e06b6b0eb9aaa9aab..71fcd8c6beca73f611aa919b237c9d719f6ec4a7 100644 --- a/include/otbTensorflowDataTypeBridge.cxx +++ b/include/otbTensorflowDataTypeBridge.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,14 +11,17 @@ =========================================================================*/ #include "otbTensorflowDataTypeBridge.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // // returns the datatype used by tensorflow // -template<class Type> -tensorflow::DataType GetTensorflowDataType() +template <class Type> +tensorflow::DataType +GetTensorflowDataType() { if (typeid(Type) == typeid(bool)) { @@ -74,11 +77,21 @@ tensorflow::DataType GetTensorflowDataType() // // Return true if the tensor data type is correct // -template<class Type> -bool HasSameDataType(const tensorflow::Tensor & tensor) +template <class Type> +bool +HasSameDataType(const tensorflow::Tensor & tensor) { return GetTensorflowDataType<Type>() == tensor.dtype(); } +// +// Return the datatype as string +// +tensorflow::string +GetDataTypeAsString(tensorflow::DataType dt) +{ + return tensorflow::DataTypeString(dt); +} + } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowDataTypeBridge.h 
b/include/otbTensorflowDataTypeBridge.h index 16e9dd23e63beff92f092bfe75fe899afe78b478..e815dafcba8bbb408a843e3ed63aa9a5a8b8dfe3 100644 --- a/include/otbTensorflowDataTypeBridge.h +++ b/include/otbTensorflowDataTypeBridge.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -16,16 +16,24 @@ #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/tensor.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // returns the datatype used by tensorflow -template<class Type> -tensorflow::DataType GetTensorflowDataType(); +template <class Type> +tensorflow::DataType +GetTensorflowDataType(); // Return true if the tensor data type is correct -template<class Type> -bool HasSameDataType(const tensorflow::Tensor & tensor); +template <class Type> +bool +HasSameDataType(const tensorflow::Tensor & tensor); + +// Return datatype as string +tensorflow::string +GetDataTypeAsString(tensorflow::DataType dt); } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowGraphOperations.cxx b/include/otbTensorflowGraphOperations.cxx index 16300f6c341a75a65947ed5bb7512d7d1187dbc7..8c91434061015a866a42f7f29654cd4451aa679f 100644 --- a/include/otbTensorflowGraphOperations.cxx +++ b/include/otbTensorflowGraphOperations.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,161 +11,180 @@ =========================================================================*/ #include 
"otbTensorflowGraphOperations.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ + // -// Restore a model from a path +// Load SavedModel variables // -void RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) { tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); checkpointPathTensor.scalar<tensorflow::tstring>()() = path; - std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = - {{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; - auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().restore_op_name()}, nullptr); + std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = { + { bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor } + }; + auto status = bundle.session->Run(feed_dict, {}, { bundle.meta_graph_def.saver_def().restore_op_name() }, nullptr); if (!status.ok()) - { - itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() ); - } + { + itkGenericExceptionMacro("Can't restore the input model: " << status.ToString()); + } } // -// Restore a model from a path +// Save SavedModel variables // -void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) { tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); checkpointPathTensor.scalar<tensorflow::tstring>()() = path; - std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = - {{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; - auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().save_tensor_name()}, nullptr); + std::vector<std::pair<std::string, tensorflow::Tensor>> 
feed_dict = { + { bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor } + }; + auto status = bundle.session->Run(feed_dict, {}, { bundle.meta_graph_def.saver_def().save_tensor_name() }, nullptr); if (!status.ok()) - { - itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() ); - } + { + itkGenericExceptionMacro("Can't restore the input model: " << status.ToString()); + } } // -// Load a session and a graph from a folder +// Load a SavedModel // -void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::vector<std::string> tagList) { + // If the tag list is empty, we push back the default tag for model serving + if (tagList.size() == 0) + tagList.push_back(tensorflow::kSavedModelTagServe); + // std::vector --> std::unordered_list + std::unordered_set<std::string> tagSets; + std::copy(tagList.begin(), tagList.end(), std::inserter(tagSets, tagSets.end())); // copy in unordered_set + + // Call to tensorflow::LoadSavedModel tensorflow::RunOptions runoptions; runoptions.set_trace_level(tensorflow::RunOptions_TraceLevel_FULL_TRACE); - auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions, - path, {tensorflow::kSavedModelTagServe}, &bundle); + auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions, path, tagSets, &bundle); if (!status.ok()) - { - itkGenericExceptionMacro("Can't load the input model: " << status.ToString() ); - } - + { + itkGenericExceptionMacro("Can't load the input model: " << status.ToString()); + } } -// -// Load a graph from a .meta file -// -tensorflow::GraphDef LoadGraph(std::string filename) -{ - tensorflow::MetaGraphDef meta_graph_def; - auto status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &meta_graph_def); - if (!status.ok()) - { - itkGenericExceptionMacro("Can't load the input model: " << 
status.ToString() ); - } - - return meta_graph_def.graph_def(); -} -// // Get the following attributes of the specified tensors (by name) of a graph: +// - layer name, as specified in the model // - shape // - datatype -// Here we assume that the node's output is a tensor -// -void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & tensorsNames, - std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes) +void +GetTensorAttributes(const tensorflow::protobuf::Map<std::string, tensorflow::TensorInfo> layers, + std::vector<std::string> & tensorsNames, + std::vector<std::string> & layerNames, + std::vector<tensorflow::TensorShapeProto> & shapes, + std::vector<tensorflow::DataType> & dataTypes) { // Allocation shapes.clear(); - shapes.reserve(tensorsNames.size()); dataTypes.clear(); - dataTypes.reserve(tensorsNames.size()); + layerNames.clear(); - // Get infos - for (std::vector<std::string>::iterator nameIt = tensorsNames.begin(); - nameIt != tensorsNames.end(); ++nameIt) - { - bool found = false; - for (int i = 0 ; i < graph.node_size() ; i++) - { - tensorflow::NodeDef node = graph.node(i); + // Debug infos + otbLogMacro(Debug, << "Nodes contained in the model: "); + for (auto const & layer : layers) + otbLogMacro(Debug, << "\t" << layer.first); + // When the user doesn't specify output.names, m_OutputTensors defaults to an empty list that we can not iterate over. + // We change it to a list containing an empty string [""] + if (tensorsNames.size() == 0) + { + otbLogMacro(Debug, << "No output.name specified. 
Using a default list with one empty string."); + tensorsNames.push_back(""); + } - if (node.name().compare((*nameIt)) == 0) - { - found = true; + // Next, we fill layerNames + int k = 0; // counter used for tensorsNames + for (auto const & name : tensorsNames) + { + bool found = false; + tensorflow::TensorInfo tensor_info; - // Set default to DT_FLOAT - tensorflow::DataType ts_dt = tensorflow::DT_FLOAT; + // If the user didn't specify the placeholdername, choose the kth layer inside the model + if (name.size() == 0) + { + found = true; + // select the k-th element of `layers` + auto it = layers.begin(); + std::advance(it, k); + layerNames.push_back(it->second.name()); + tensor_info = it->second; + otbLogMacro(Debug, << "Input " << k << " corresponds to " << it->first << " in the model"); + } - // Default (input?) tensor type - auto test_is_output = node.attr().find("T"); - if (test_is_output != node.attr().end()) - { - ts_dt = node.attr().at("T").type(); - } - auto test_has_dtype = node.attr().find("dtype"); - if (test_has_dtype != node.attr().end()) - { - ts_dt = node.attr().at("dtype").type(); - } - auto test_output_type = node.attr().find("output_type"); - if (test_output_type != node.attr().end()) + // Else, if the user specified the placeholdername, find the corresponding layer inside the model + else + { + otbLogMacro(Debug, << "Searching for corresponding node of: " << name << "... 
"); + for (auto const & layer : layers) + { + // layer is a pair (name, tensor_info) + // cf https://stackoverflow.com/questions/63181951/how-to-get-graph-or-graphdef-from-a-given-model + std::string layername = layer.first; + if (layername.substr(0, layername.find(":")).compare(name) == 0) { - // if there is an output type, we take it instead of the - // datatype of the input tensor - ts_dt = node.attr().at("output_type").type(); + found = true; + layerNames.push_back(layer.second.name()); + tensor_info = layer.second; + otbLogMacro(Debug, << "Found: " << layer.second.name() << " in the model"); } - dataTypes.push_back(ts_dt); + } // next layer + } // end else - // Get the tensor's shape - // Here we assure it's a tensor, with 1 shape - tensorflow::TensorShapeProto ts_shp = node.attr().at("_output_shapes").list().shape(0); - shapes.push_back(ts_shp); - } - } + k += 1; if (!found) { - itkGenericExceptionMacro("Tensor name \"" << (*nameIt) << "\" not found" ); + itkGenericExceptionMacro("Tensor name \"" << name << "\" not found. 
\n" + << "You can list all inputs/outputs of your SavedModel by " + << "running: \n\t `saved_model_cli show --dir your_model_dir --all`"); } - } + // Default tensor type + tensorflow::DataType ts_dt = tensor_info.dtype(); + dataTypes.push_back(ts_dt); + // Get the tensor's shape + // Here we assure it's a tensor, with 1 shape + tensorflow::TensorShapeProto ts_shp = tensor_info.tensor_shape(); + shapes.push_back(ts_shp); + } // next tensor name } // // Print a lot of stuff about the specified nodes of the graph // -void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & nodesNames) +void +PrintNodeAttributes(const tensorflow::GraphDef & graph, const std::vector<std::string> & nodesNames) { std::cout << "Go through graph:" << std::endl; std::cout << "#\tname" << std::endl; - for (int i = 0 ; i < graph.node_size() ; i++) + for (int i = 0; i < graph.node_size(); i++) { tensorflow::NodeDef node = graph.node(i); std::cout << i << "\t" << node.name() << std::endl; - for (std::vector<std::string>::iterator nameIt = nodesNames.begin(); - nameIt != nodesNames.end(); ++nameIt) + for (auto const & name : nodesNames) { - if (node.name().compare((*nameIt)) == 0) + if (node.name().compare(name) == 0) { std::cout << "Node " << i << " : " << std::endl; std::cout << "\tName: " << node.name() << std::endl; - std::cout << "\tinput_size() : " << node.input_size() << std::endl; + std::cout << "\tinput_size(): " << node.input_size() << std::endl; std::cout << "\tPrintDebugString --------------------------------"; std::cout << std::endl; node.PrintDebugString(); @@ -173,20 +192,19 @@ void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::st // display all attributes of the node std::cout << "\tAttributes of the node: " << std::endl; - for (auto attr = node.attr().begin() ; attr != node.attr().end() ; attr++) + for (auto attr = node.attr().begin(); attr != node.attr().end(); attr++) { - std::cout << "\t\tKey :" << attr->first << 
std::endl; - std::cout << "\t\tValue.value_case() :" << attr->second.value_case() << std::endl; + std::cout << "\t\tKey: " << attr->first << std::endl; + std::cout << "\t\tValue.value_case(): " << attr->second.value_case() << std::endl; std::cout << "\t\tPrintDebugString --------------------------------"; std::cout << std::endl; attr->second.PrintDebugString(); std::cout << "\t\t-------------------------------------------------" << std::endl; std::cout << std::endl; } // next attribute - } // node name match - } // next node name - } // next node of the graph - + } // node name match + } // next node name + } // next node of the graph } } // end namespace tf diff --git a/include/otbTensorflowGraphOperations.h b/include/otbTensorflowGraphOperations.h index 4b4e93c016a09eaab5182ceb6800126665ccad92..b249508694406bac2424194b43a3013bb69b2866 100644 --- a/include/otbTensorflowGraphOperations.h +++ b/include/otbTensorflowGraphOperations.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -24,30 +24,39 @@ // ITK exception #include "itkMacro.h" -namespace otb { -namespace tf { +// OTB log +#include "otbMacro.h" -// Restore a model from a path -void RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); +namespace otb +{ +namespace tf +{ -// Restore a model from a path -void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); +// Load SavedModel variables +void +RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); -// Load a session and a graph from a folder -void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); +// Save SavedModel variables +void +SaveModel(const tensorflow::tstring 
path, tensorflow::SavedModelBundle & bundle); -// Load a graph from a .meta file -tensorflow::GraphDef LoadGraph(std::string filename); +// Load SavedModel +void +LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::vector<std::string> tagList); // Get the following attributes of the specified tensors (by name) of a graph: // - shape // - datatype // Here we assume that the node's output is a tensor -void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & tensorsNames, - std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes); +void +GetTensorAttributes(const tensorflow::protobuf::Map<std::string, tensorflow::TensorInfo> layers, + std::vector<std::string> & tensorsNames, std::vector<std::string> & layerNames, + std::vector<tensorflow::TensorShapeProto> & shapes, + std::vector<tensorflow::DataType> & dataTypes); // Print a lot of stuff about the specified nodes of the graph -void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & nodesNames); +void +PrintNodeAttributes(const tensorflow::GraphDef & graph, const std::vector<std::string> & nodesNames); } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h index 9fb0a79ce7e48144af182dc02063f3c8911c7b7b..a1418b0c66e5d420e88690a15d1f89ff90b71503 100644 --- a/include/otbTensorflowMultisourceModelBase.h +++ b/include/otbTensorflowMultisourceModelBase.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -15,10 +15,12 @@ #include "itkProcessObject.h" #include "itkNumericTraits.h" #include "itkSimpleDataObjectDecorator.h" +#include "itkImageToImageFilter.h" // Tensorflow
#include "tensorflow/core/public/session.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/cc/saved_model/signature_constants.h" // Tensorflow helpers #include "otbTensorflowGraphOperations.h" @@ -45,8 +47,7 @@ namespace otb * be the same. If not, an exception will be thrown during the method * GenerateOutputInformation(). * - * The TensorFlow graph and session must be set using the SetGraph() and - * SetSession() methods. + * The TensorFlow SavedModel pointer must be set using the SetSavedModel() method. * * Target nodes names of the TensorFlow graph that must be triggered can be set * with the SetTargetNodesNames. @@ -64,34 +65,32 @@ namespace otb * * \ingroup OTBTensorflow */ -template <class TInputImage, class TOutputImage=TInputImage> -class ITK_EXPORT TensorflowMultisourceModelBase : -public itk::ImageToImageFilter<TInputImage, TOutputImage> +template <class TInputImage, class TOutputImage = TInputImage> +class ITK_EXPORT TensorflowMultisourceModelBase : public itk::ImageToImageFilter<TInputImage, TOutputImage> { public: - /** Standard class typedefs. */ - typedef TensorflowMultisourceModelBase Self; + typedef TensorflowMultisourceModelBase Self; typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Run-time type information (and related methods). 
*/ itkTypeMacro(TensorflowMultisourceModelBase, itk::ImageToImageFilter); /** Images typedefs */ - typedef TInputImage ImageType; - typedef typename TInputImage::Pointer ImagePointerType; - typedef typename TInputImage::PixelType PixelType; - typedef typename TInputImage::InternalPixelType InternalPixelType; - typedef typename TInputImage::IndexType IndexType; - typedef typename TInputImage::IndexValueType IndexValueType; - typedef typename TInputImage::PointType PointType; - typedef typename TInputImage::SizeType SizeType; - typedef typename TInputImage::SizeValueType SizeValueType; - typedef typename TInputImage::SpacingType SpacingType; - typedef typename TInputImage::RegionType RegionType; + typedef TInputImage ImageType; + typedef typename TInputImage::Pointer ImagePointerType; + typedef typename TInputImage::PixelType PixelType; + typedef typename TInputImage::InternalPixelType InternalPixelType; + typedef typename TInputImage::IndexType IndexType; + typedef typename TInputImage::IndexValueType IndexValueType; + typedef typename TInputImage::PointType PointType; + typedef typename TInputImage::SizeType SizeType; + typedef typename TInputImage::SizeValueType SizeValueType; + typedef typename TInputImage::SpacingType SpacingType; + typedef typename TInputImage::RegionType RegionType; /** Typedefs for parameters */ typedef std::pair<std::string, tensorflow::Tensor> DictElementType; @@ -103,14 +102,26 @@ public: typedef std::vector<tensorflow::Tensor> TensorListType; /** Set and Get the Tensorflow session and graph */ - void SetGraph(tensorflow::GraphDef graph) { m_Graph = graph; } - tensorflow::GraphDef GetGraph() { return m_Graph ; } - void SetSession(tensorflow::Session * session) { m_Session = session; } - tensorflow::Session * GetSession() { return m_Session; } + void + SetSavedModel(tensorflow::SavedModelBundle * saved_model) + { + m_SavedModel = saved_model; + } + tensorflow::SavedModelBundle * + GetSavedModel() + { + return m_SavedModel; + } + + /** Get 
the SignatureDef */ + tensorflow::SignatureDef + GetSignatureDef(); /** Model parameters */ - void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image); - void PushBackOuputTensorBundle(std::string name, SizeType expressionField); + void + PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image); + void + PushBackOuputTensorBundle(std::string name, SizeType expressionField); /** Input placeholders names */ itkSetMacro(InputPlaceholders, StringList); @@ -129,8 +140,16 @@ public: itkGetMacro(OutputExpressionFields, SizeListType); /** User placeholders */ - void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; } - DictType GetUserPlaceholders() { return m_UserPlaceholders; } + void + SetUserPlaceholders(const DictType & dict) + { + m_UserPlaceholders = dict; + } + DictType + GetUserPlaceholders() + { + return m_UserPlaceholders; + } /** Target nodes names */ itkSetMacro(TargetNodesNames, StringList); @@ -142,37 +161,44 @@ public: itkGetMacro(InputTensorsShapes, TensorShapeProtoList); itkGetMacro(OutputTensorsShapes, TensorShapeProtoList); - virtual void GenerateOutputInformation(); + virtual void + GenerateOutputInformation(); protected: TensorflowMultisourceModelBase(); - virtual ~TensorflowMultisourceModelBase() {}; + virtual ~TensorflowMultisourceModelBase(){}; - virtual std::stringstream GenerateDebugReport(DictType & inputs); + virtual std::stringstream + GenerateDebugReport(DictType & inputs); - virtual void RunSession(DictType & inputs, TensorListType & outputs); + virtual void + RunSession(DictType & inputs, TensorListType & outputs); private: - TensorflowMultisourceModelBase(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowMultisourceModelBase(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented // Tensorflow graph and session - tensorflow::GraphDef 
m_Graph; // The TensorFlow graph - tensorflow::Session * m_Session; // The TensorFlow session + tensorflow::SavedModelBundle * m_SavedModel; // The TensorFlow model // Model parameters - StringList m_InputPlaceholders; // Input placeholders names - SizeListType m_InputReceptiveFields; // Input receptive fields - StringList m_OutputTensors; // Output tensors names - SizeListType m_OutputExpressionFields; // Output expression fields - DictType m_UserPlaceholders; // User placeholders - StringList m_TargetNodesNames; // User nodes target + StringList m_InputPlaceholders; // Input placeholders names + SizeListType m_InputReceptiveFields; // Input receptive fields + StringList m_OutputTensors; // Output tensors names + SizeListType m_OutputExpressionFields; // Output expression fields + DictType m_UserPlaceholders; // User placeholders + StringList m_TargetNodesNames; // User nodes target // Internal, read-only - DataTypeListType m_InputTensorsDataTypes; // Input tensors datatype - DataTypeListType m_OutputTensorsDataTypes; // Output tensors datatype - TensorShapeProtoList m_InputTensorsShapes; // Input tensors shapes - TensorShapeProtoList m_OutputTensorsShapes; // Output tensors shapes + DataTypeListType m_InputTensorsDataTypes; // Input tensors datatype + DataTypeListType m_OutputTensorsDataTypes; // Output tensors datatype + TensorShapeProtoList m_InputTensorsShapes; // Input tensors shapes + TensorShapeProtoList m_OutputTensorsShapes; // Output tensors shapes + + // Layer names inside the model corresponding to inputs and outputs + StringList m_InputLayers; // List of input names, as contained in the model + StringList m_OutputLayers; // List of output names, as contained in the model }; // end class diff --git a/include/otbTensorflowMultisourceModelBase.hxx b/include/otbTensorflowMultisourceModelBase.hxx index 02c98baeea3ec32f36f6a707a6a99d75fc99e79c..ccefde8f178935ccde6f00c13157b4ba95f1a929 100644 --- a/include/otbTensorflowMultisourceModelBase.hxx +++ 
b/include/otbTensorflowMultisourceModelBase.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,29 +18,56 @@ namespace otb { template <class TInputImage, class TOutputImage> -TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::TensorflowMultisourceModelBase() - { - m_Session = nullptr; - Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max() ); - Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max() ); - } +TensorflowMultisourceModelBase<TInputImage, TOutputImage>::TensorflowMultisourceModelBase() +{ + Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max()); + Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max()); + + m_SavedModel = nullptr; +} + +template <class TInputImage, class TOutputImage> +tensorflow::SignatureDef +TensorflowMultisourceModelBase<TInputImage, TOutputImage>::GetSignatureDef() +{ + auto signatures = this->GetSavedModel()->GetSignatures(); + tensorflow::SignatureDef signature_def; + + if (signatures.size() == 0) + { + itkExceptionMacro("There are no available signatures for this tag-set. 
\n" + << "Please check which tag-set to use by running " + << "`saved_model_cli show --dir your_model_dir --all`"); + } + + // If serving_default key exists (which is the default for TF saved model), choose it as signature + // Else, choose the first one + if (signatures.contains(tensorflow::kDefaultServingSignatureDefKey)) + { + signature_def = signatures.at(tensorflow::kDefaultServingSignatureDefKey); + } + else + { + signature_def = signatures.begin()->second; + } + return signature_def; +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image) - { +TensorflowMultisourceModelBase<TInputImage, TOutputImage>::PushBackInputTensorBundle(std::string placeholder, + SizeType receptiveField, + ImagePointerType image) +{ Superclass::PushBackInput(image); m_InputReceptiveFields.push_back(receptiveField); m_InputPlaceholders.push_back(placeholder); - } +} template <class TInputImage, class TOutputImage> std::stringstream -TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::GenerateDebugReport(DictType & inputs) - { +TensorflowMultisourceModelBase<TInputImage, TOutputImage>::GenerateDebugReport(DictType & inputs) +{ // Create a debug report std::stringstream debugReport; @@ -50,62 +77,72 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> debugReport << "Output image buffered region: " << outputReqRegion << "\n"; // Describe inputs - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { - const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i)); - const RegionType reqRegion = inputPtr->GetRequestedRegion(); + for (unsigned int i = 0; i < this->GetNumberOfInputs(); i++) + { + const ImagePointerType inputPtr = const_cast<TInputImage *>(this->GetInput(i)); + const RegionType reqRegion = inputPtr->GetRequestedRegion(); debugReport << "Input #" << i << ":\n"; 
debugReport << "Requested region: " << reqRegion << "\n"; - debugReport << "Tensor shape (\"" << inputs[i].first << "\": " << tf::PrintTensorShape(inputs[i].second.shape()) << "\n"; - } + debugReport << "Tensor \"" << inputs[i].first << "\": " << tf::PrintTensorInfos(inputs[i].second) << "\n"; + } // Show user placeholders - debugReport << "User placeholders:\n" ; - for (auto& dict: this->GetUserPlaceholders()) - { - debugReport << dict.first << " " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl; - } + debugReport << "User placeholders:\n"; + for (auto & dict : this->GetUserPlaceholders()) + { + debugReport << "Tensor \"" << dict.first << "\": " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl; + } return debugReport; - } +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::RunSession(DictType & inputs, TensorListType & outputs) - { +TensorflowMultisourceModelBase<TInputImage, TOutputImage>::RunSession(DictType & inputs, TensorListType & outputs) +{ // Add the user's placeholders - for (auto& dict: this->GetUserPlaceholders()) - { - inputs.push_back(dict); - } + std::copy(this->GetUserPlaceholders().begin(), this->GetUserPlaceholders().end(), std::back_inserter(inputs)); // Run the TF session here // The session will initialize the outputs + // `inputs` corresponds to a mapping {name, tensor}, with the name being specified by the user when calling + // TensorFlowModelServe we must adapt it to `inputs_new`, that corresponds to a mapping {layerName, tensor}, with the + // layerName being from the model + DictType inputs_new; + int k = 0; + for (auto & dict : inputs) + { + DictElementType element = { m_InputLayers[k], dict.second }; + inputs_new.push_back(element); + k += 1; + } + // Run the session, evaluating our output tensors from the graph - auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs); - if (!status.ok()) { + auto status = 
this->GetSavedModel()->session.get()->Run(inputs_new, m_OutputLayers, m_TargetNodesNames, &outputs); + + if (!status.ok()) + { // Create a debug report std::stringstream debugReport = GenerateDebugReport(inputs); // Throw an exception with the report - itkExceptionMacro("Can't run the tensorflow session !\n" << - "Tensorflow error message:\n" << status.ToString() << "\n" - "OTB Filter debug message:\n" << debugReport.str() ); - + itkExceptionMacro("Can't run the tensorflow session !\n" + << "Tensorflow error message:\n" + << status.ToString() + << "\n" + "OTB Filter debug message:\n" + << debugReport.str()); } - - } +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::GenerateOutputInformation() - { +TensorflowMultisourceModelBase<TInputImage, TOutputImage>::GenerateOutputInformation() +{ // Check that the number of the following is the same // - input placeholders names @@ -113,30 +150,29 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> // - input images const unsigned int nbInputs = this->GetNumberOfInputs(); if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size()) - { - itkExceptionMacro("Number of input images is " << nbInputs << - " but the number of input patches size is " << m_InputReceptiveFields.size() << - " and the number of input tensors names is " << m_InputPlaceholders.size()); - } - - // Check that the number of the following is the same - // - output tensors names - // - output expression fields - if (m_OutputExpressionFields.size() != m_OutputTensors.size()) - { - itkExceptionMacro("Number of output tensors names is " << m_OutputTensors.size() << - " but the number of output fields of expression is " << m_OutputExpressionFields.size()); - } + { + itkExceptionMacro("Number of input images is " + << nbInputs << " but the number of input patches size is " << m_InputReceptiveFields.size() + << " and the number of input tensors names is " << 
m_InputPlaceholders.size()); + } ////////////////////////////////////////////////////////////////////////////////////////// // Get tensors information ////////////////////////////////////////////////////////////////////////////////////////// - - // Get input and output tensors datatypes and shapes - tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes); - tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes); - - } + // Set all subelement of the model + auto signaturedef = this->GetSignatureDef(); + + // Given the inputs/outputs names that the user specified, get the names of the inputs/outputs contained in the model + // and other infos (shapes, dtypes) + // For example, for output names specified by the user m_OutputTensors = ['s2t', 's2t_pad'], + // this will return m_OutputLayers = ['PartitionedCall:0', 'PartitionedCall:1'] + // In case the user hasn't named the output, e.g. m_OutputTensors = [''], + // this will return the first output m_OutputLayers = ['PartitionedCall:0'] + tf::GetTensorAttributes( + signaturedef.inputs(), m_InputPlaceholders, m_InputLayers, m_InputTensorsShapes, m_InputTensorsDataTypes); + tf::GetTensorAttributes( + signaturedef.outputs(), m_OutputTensors, m_OutputLayers, m_OutputTensorsShapes, m_OutputTensorsDataTypes); +} } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelFilter.h b/include/otbTensorflowMultisourceModelFilter.h index 46a273af36f9a4aefd6a365ca36b9557eebab37e..bdf9a02d0b00e9dbc228f3a0401ac8dab4c49e32 100644 --- a/include/otbTensorflowMultisourceModelFilter.h +++ b/include/otbTensorflowMultisourceModelFilter.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; 
without even @@ -17,15 +17,13 @@ // Iterator #include "itkImageRegionConstIteratorWithOnlyIndex.h" -// Tensorflow helpers -#include "otbTensorflowGraphOperations.h" -#include "otbTensorflowDataTypeBridge.h" -#include "otbTensorflowCopyUtils.h" - // Tile hint #include "itkMetaDataObject.h" #include "otbMetaDataKey.h" +// OTB log +#include "otbMacro.h" + namespace otb { @@ -82,12 +80,10 @@ namespace otb * \ingroup OTBTensorflow */ template <class TInputImage, class TOutputImage> -class ITK_EXPORT TensorflowMultisourceModelFilter : -public TensorflowMultisourceModelBase<TInputImage, TOutputImage> +class ITK_EXPORT TensorflowMultisourceModelFilter : public TensorflowMultisourceModelBase<TInputImage, TOutputImage> { public: - /** Standard class typedefs. */ typedef TensorflowMultisourceModelFilter Self; typedef TensorflowMultisourceModelBase<TInputImage, TOutputImage> Superclass; @@ -101,16 +97,16 @@ public: itkTypeMacro(TensorflowMultisourceModelFilter, TensorflowMultisourceModelBase); /** Images typedefs */ - typedef typename Superclass::ImageType ImageType; - typedef typename Superclass::ImagePointerType ImagePointerType; - typedef typename Superclass::PixelType PixelType; - typedef typename Superclass::IndexType IndexType; - typedef typename IndexType::IndexValueType IndexValueType; - typedef typename Superclass::PointType PointType; - typedef typename Superclass::SizeType SizeType; - typedef typename SizeType::SizeValueType SizeValueType; - typedef typename Superclass::SpacingType SpacingType; - typedef typename Superclass::RegionType RegionType; + typedef typename Superclass::ImageType ImageType; + typedef typename Superclass::ImagePointerType ImagePointerType; + typedef typename Superclass::PixelType PixelType; + typedef typename Superclass::IndexType IndexType; + typedef typename IndexType::IndexValueType IndexValueType; + typedef typename Superclass::PointType PointType; + typedef typename Superclass::SizeType SizeType; + typedef typename 
SizeType::SizeValueType SizeValueType; + typedef typename Superclass::SpacingType SpacingType; + typedef typename Superclass::RegionType RegionType; typedef TOutputImage OutputImageType; typedef typename TOutputImage::PixelType OutputPixelType; @@ -121,12 +117,12 @@ public: typedef typename itk::ImageRegionConstIterator<TInputImage> InputConstIteratorType; /* Typedefs for parameters */ - typedef typename Superclass::DictElementType DictElementType; - typedef typename Superclass::DictType DictType; - typedef typename Superclass::StringList StringList; - typedef typename Superclass::SizeListType SizeListType; - typedef typename Superclass::TensorListType TensorListType; - typedef std::vector<float> ScaleListType; + typedef typename Superclass::DictElementType DictElementType; + typedef typename Superclass::DictType DictType; + typedef typename Superclass::StringList StringList; + typedef typename Superclass::SizeListType SizeListType; + typedef typename Superclass::TensorListType TensorListType; + typedef std::vector<float> ScaleListType; itkSetMacro(OutputGridSize, SizeType); itkGetMacro(OutputGridSize, SizeType); @@ -139,34 +135,43 @@ public: protected: TensorflowMultisourceModelFilter(); - virtual ~TensorflowMultisourceModelFilter() {}; + virtual ~TensorflowMultisourceModelFilter(){}; - virtual void SmartPad(RegionType& region, const SizeType &patchSize); - virtual void SmartShrink(RegionType& region, const SizeType &patchSize); - virtual void ImageToExtent(ImageType* image, PointType &extentInf, PointType &extentSup, SizeType &patchSize); - virtual bool OutputRegionToInputRegion(const RegionType &outputRegion, RegionType &inputRegion, ImageType* &inputImage); - virtual void EnlargeToAlignedRegion(RegionType& region); + virtual void + SmartPad(RegionType & region, const SizeType & patchSize); + virtual void + SmartShrink(RegionType & region, const SizeType & patchSize); + virtual void + ImageToExtent(ImageType * image, PointType & extentInf, PointType & extentSup, 
SizeType & patchSize); + virtual bool + OutputRegionToInputRegion(const RegionType & outputRegion, RegionType & inputRegion, ImageType *& inputImage); + virtual void + EnlargeToAlignedRegion(RegionType & region); - virtual void GenerateOutputInformation(void); + virtual void + GenerateOutputInformation(void); - virtual void GenerateInputRequestedRegion(void); + virtual void + GenerateInputRequestedRegion(void); - virtual void GenerateData(); + virtual void + GenerateData(); private: - TensorflowMultisourceModelFilter(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowMultisourceModelFilter(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented - SizeType m_OutputGridSize; // Output grid size - bool m_ForceOutputGridSize; // Force output grid size - bool m_FullyConvolutional; // Convolution mode - float m_OutputSpacingScale; // scaling of the output spacings + SizeType m_OutputGridSize; // Output grid size + bool m_ForceOutputGridSize; // Force output grid size + bool m_FullyConvolutional; // Convolution mode + float m_OutputSpacingScale; // scaling of the output spacings // Internal - SpacingType m_OutputSpacing; // Output image spacing - PointType m_OutputOrigin; // Output image origin - SizeType m_OutputSize; // Output image size - PixelType m_NullPixel; // Pixel filled with zeros + SpacingType m_OutputSpacing; // Output image spacing + PointType m_OutputOrigin; // Output image origin + SizeType m_OutputSize; // Output image size + PixelType m_NullPixel; // Pixel filled with zeros }; // end class diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx index 91c5384ea1b508fcd359c120a067ed99c4a11126..3cbb53d92857466d617e5547940c8e42a0ce971e 100644 --- a/include/otbTensorflowMultisourceModelFilter.hxx +++ b/include/otbTensorflowMultisourceModelFilter.hxx @@ -1,7 +1,7 @@ 
/*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,9 +18,8 @@ namespace otb { template <class TInputImage, class TOutputImage> -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::TensorflowMultisourceModelFilter() - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::TensorflowMultisourceModelFilter() +{ m_OutputGridSize.Fill(0); m_ForceOutputGridSize = false; m_FullyConvolutional = false; @@ -31,38 +30,37 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> m_OutputSpacingScale = 1.0f; - Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max() ); - Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max() ); - } + Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max()); + Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max()); +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::SmartPad(RegionType& region, const SizeType &patchSize) - { - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::SmartPad(RegionType & region, const SizeType & patchSize) +{ + for (unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { const SizeValueType psz = patchSize[dim]; const SizeValueType rval = 0.5 * psz; const SizeValueType lval = psz - rval; region.GetModifiableIndex()[dim] -= lval; region.GetModifiableSize()[dim] += psz; - } - } + } +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::SmartShrink(RegionType& region, const SizeType &patchSize) - { - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; 
++dim) - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::SmartShrink(RegionType & region, + const SizeType & patchSize) +{ + for (unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { const SizeValueType psz = patchSize[dim]; const SizeValueType lval = 0.5 * psz; region.GetModifiableIndex()[dim] += lval; region.GetModifiableSize()[dim] -= psz - 1; - } - } + } +} /** Compute the input image extent: corners inf and sup. @@ -70,9 +68,11 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> */ template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::ImageToExtent(ImageType* image, PointType &extentInf, PointType &extentSup, SizeType &patchSize) - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::ImageToExtent(ImageType * image, + PointType & extentInf, + PointType & extentSup, + SizeType & patchSize) +{ // Get largest possible region RegionType largestPossibleRegion = image->GetLargestPossibleRegion(); @@ -89,13 +89,12 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> PointType imageEnd; image->TransformIndexToPhysicalPoint(imageLastIndex, imageEnd); image->TransformIndexToPhysicalPoint(imageFirstIndex, imageOrigin); - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) - { + for (unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { extentInf[dim] = vnl_math_min(imageOrigin[dim], imageEnd[dim]) - 0.5 * image->GetSpacing()[dim]; extentSup[dim] = vnl_math_max(imageOrigin[dim], imageEnd[dim]) + 0.5 * image->GetSpacing()[dim]; - } - - } + } +} /** Compute the region of the input image which correspond to the given output requested region @@ -104,9 +103,10 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> */ template <class TInputImage, class TOutputImage> bool -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::OutputRegionToInputRegion(const RegionType &outputRegion, RegionType 
&inputRegion, ImageType* &inputImage) - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::OutputRegionToInputRegion(const RegionType & outputRegion, + RegionType & inputRegion, + ImageType *& inputImage) +{ // Mosaic Region Start & End (mosaic image index) const IndexType outIndexStart = outputRegion.GetIndex(); @@ -115,45 +115,43 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Mosaic Region Start & End (geo) PointType outPointStart, outPointEnd; this->GetOutput()->TransformIndexToPhysicalPoint(outIndexStart, outPointStart); - this->GetOutput()->TransformIndexToPhysicalPoint(outIndexEnd , outPointEnd ); + this->GetOutput()->TransformIndexToPhysicalPoint(outIndexEnd, outPointEnd); // Add the half-width pixel size of the input image // and remove the half-width pixel size of the output image // (coordinates = pixel center) const SpacingType outputSpc = this->GetOutput()->GetSpacing(); const SpacingType inputSpc = inputImage->GetSpacing(); - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) - { - const typename SpacingType::ValueType border = - 0.5 * (inputSpc[dim] - outputSpc[dim]); + for (unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { + const typename SpacingType::ValueType border = 0.5 * (inputSpc[dim] - outputSpc[dim]); if (outPointStart[dim] < outPointEnd[dim]) - { + { outPointStart[dim] += border; - outPointEnd [dim] -= border; - } + outPointEnd[dim] -= border; + } else - { + { outPointStart[dim] -= border; - outPointEnd [dim] += border; - } + outPointEnd[dim] += border; } + } // Mosaic Region Start & End (input image index) IndexType defIndexStart, defIndexEnd; inputImage->TransformPhysicalPointToIndex(outPointStart, defIndexStart); - inputImage->TransformPhysicalPointToIndex(outPointEnd , defIndexEnd); + inputImage->TransformPhysicalPointToIndex(outPointEnd, defIndexEnd); // Compute input image region - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) - { + for 
(unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { inputRegion.SetIndex(dim, vnl_math_min(defIndexStart[dim], defIndexEnd[dim])); inputRegion.SetSize(dim, vnl_math_max(defIndexStart[dim], defIndexEnd[dim]) - inputRegion.GetIndex(dim) + 1); - } + } // crop the input requested region at the input's largest possible region - return inputRegion.Crop( inputImage->GetLargestPossibleRegion() ); - - } + return inputRegion.Crop(inputImage->GetLargestPossibleRegion()); +} /* * Enlarge the given region to the nearest aligned region. @@ -161,11 +159,10 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> */ template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::EnlargeToAlignedRegion(RegionType& region) - { - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::EnlargeToAlignedRegion(RegionType & region) +{ + for (unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { // Get corners IndexValueType lower = region.GetIndex(dim); IndexValueType upper = lower + region.GetSize(dim); @@ -177,22 +174,20 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Move corners to aligned positions lower -= deltaLo; if (deltaUp > 0) - { + { upper += m_OutputGridSize[dim] - deltaUp; - } + } // Update region region.SetIndex(dim, lower); region.SetSize(dim, upper - lower); - - } - } + } +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::GenerateOutputInformation() - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::GenerateOutputInformation() +{ Superclass::GenerateOutputInformation(); @@ -201,39 +196,45 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> ////////////////////////////////////////////////////////////////////////////////////////// // If the output spacing is not specified, we use the 
first input image as grid reference - m_OutputSpacing = this->GetInput(0)->GetSignedSpacing(); + // OTBTF assumes that the output image has the following geometric properties: + // (1) Image origin is the top-left pixel + // (2) Image pixel spacing has positive x-spacing and negative y-spacing + m_OutputSpacing = this->GetInput(0)->GetSpacing(); // GetSpacing() returns abs. spacing + m_OutputSpacing[1] *= -1.0; // Force negative y-spacing m_OutputSpacing[0] *= m_OutputSpacingScale; m_OutputSpacing[1] *= m_OutputSpacingScale; - PointType extentInf, extentSup; - extentSup.Fill(itk::NumericTraits<double>::max()); - extentInf.Fill(itk::NumericTraits<double>::NonpositiveMin()); // Compute the extent of each input images and update the extent or the output image. // The extent of the output image is the intersection of all input images extents. - for (unsigned int imageIndex = 0 ; imageIndex < this->GetNumberOfInputs() ; imageIndex++) - { - ImageType * currentImage = static_cast<ImageType *>( - Superclass::ProcessObject::GetInput(imageIndex) ); + PointType extentInf, extentSup; + extentSup.Fill(itk::NumericTraits<double>::max()); + extentInf.Fill(itk::NumericTraits<double>::NonpositiveMin()); + for (unsigned int imageIndex = 0; imageIndex < this->GetNumberOfInputs(); imageIndex++) + { + ImageType * currentImage = static_cast<ImageType *>(Superclass::ProcessObject::GetInput(imageIndex)); // Update output image extent PointType currentInputImageExtentInf, currentInputImageExtentSup; - ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]); - for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim) - { + ImageToExtent(currentImage, + currentInputImageExtentInf, + currentInputImageExtentSup, + this->GetInputReceptiveFields()[imageIndex]); + for (unsigned int dim = 0; dim < ImageType::ImageDimension; ++dim) + { extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]); 
extentSup[dim] = vnl_math_min(currentInputImageExtentSup[dim], extentSup[dim]); - } } + } // Set final origin, aligned to the reference image grid. // Here we simply get back to the center of the pixel (extents are pixels corners coordinates) - m_OutputOrigin[0] = extentInf[0] + 0.5 * this->GetInput(0)->GetSpacing()[0]; - m_OutputOrigin[1] = extentSup[1] - 0.5 * this->GetInput(0)->GetSpacing()[1]; + m_OutputOrigin[0] = extentInf[0] + 0.5 * this->GetInput(0)->GetSpacing()[0]; + m_OutputOrigin[1] = extentSup[1] - 0.5 * this->GetInput(0)->GetSpacing()[1]; // Set final size - m_OutputSize[0] = std::floor( (extentSup[0] - extentInf[0]) / std::abs(m_OutputSpacing[0]) ); - m_OutputSize[1] = std::floor( (extentSup[1] - extentInf[1]) / std::abs(m_OutputSpacing[1]) ); + m_OutputSize[0] = std::floor((extentSup[0] - extentInf[0]) / std::abs(m_OutputSpacing[0])); + m_OutputSize[1] = std::floor((extentSup[1] - extentInf[1]) / std::abs(m_OutputSpacing[1])); // We should take in account one more thing: the expression field. It enlarge slightly the output image extent. 
m_OutputOrigin[0] -= m_OutputSpacing[0] * std::floor(0.5 * this->GetOutputExpressionFields().at(0)[0]); @@ -243,18 +244,18 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Set output grid size if (!m_ForceOutputGridSize) - { + { // Default is the output field of expression m_OutputGridSize = this->GetOutputExpressionFields().at(0); - } + } // Resize the largestPossibleRegion to be a multiple of the grid size - for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim) - { + for (unsigned int dim = 0; dim < ImageType::ImageDimension; ++dim) + { if (m_OutputGridSize[dim] > m_OutputSize[dim]) itkGenericExceptionMacro("Output grid size is larger than output image size !"); m_OutputSize[dim] -= m_OutputSize[dim] % m_OutputGridSize[dim]; - } + } // Set the largest possible region RegionType largestPossibleRegion; @@ -265,33 +266,39 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> ////////////////////////////////////////////////////////////////////////////////////////// unsigned int outputPixelSize = 0; - for (auto& protoShape: this->GetOutputTensorsShapes()) + for (auto & protoShape : this->GetOutputTensorsShapes()) + { + // Find the number of components + if (protoShape.dim_size() > 4) { - // The number of components per pixel is the last dimension of the tensor - int dim_size = protoShape.dim_size(); - unsigned int nComponents = 1; - if (1 < dim_size && dim_size <= 4) - { - nComponents = protoShape.dim(dim_size-1).size(); - } - else if (dim_size > 4) - { - itkExceptionMacro("Dim_size=" << dim_size << " currently not supported."); - } - outputPixelSize += nComponents; + itkExceptionMacro("dim_size=" << protoShape.dim_size() + << " currently not supported. " + "Keep in mind that output tensors must have 1, 2, 3 or 4 dimensions. " + "In the case of 1-dimensional tensor, the first dimension is for the batch, " + "and we assume that the output tensor has 1 channel. 
" + "In the case of 2-dimensional tensor, the first dimension is for the batch, " + "and the second is the number of components. " + "In the case of 3-dimensional tensor, the first dimension is for the batch, " + "and other dims are for (x, y). " + "In the case of 4-dimensional tensor, the first dimension is for the batch, " + "and the second and the third are for (x, y). The last is for the number of " + "channels. "); } + unsigned int nComponents = tf::GetNumberOfChannelsFromShapeProto(protoShape); + outputPixelSize += nComponents; + } // Copy input image projection - ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) ); + ImageType * inputImage = static_cast<ImageType *>(Superclass::ProcessObject::GetInput(0)); const std::string projectionRef = inputImage->GetProjectionRef(); // Set output image origin/spacing/size/projection ImageType * outputPtr = this->GetOutput(); outputPtr->SetNumberOfComponentsPerPixel(outputPixelSize); - outputPtr->SetProjectionRef ( projectionRef ); - outputPtr->SetOrigin ( m_OutputOrigin ); - outputPtr->SetSignedSpacing ( m_OutputSpacing ); - outputPtr->SetLargestPossibleRegion( largestPossibleRegion ); + outputPtr->SetProjectionRef(projectionRef); + outputPtr->SetOrigin(m_OutputOrigin); + outputPtr->SetSignedSpacing(m_OutputSpacing); + outputPtr->SetLargestPossibleRegion(largestPossibleRegion); // Set null pixel m_NullPixel.SetSize(outputPtr->GetNumberOfComponentsPerPixel()); @@ -303,14 +310,12 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> itk::EncapsulateMetaData(outputPtr->GetMetaDataDictionary(), MetaDataKey::TileHintX, m_OutputGridSize[0]); itk::EncapsulateMetaData(outputPtr->GetMetaDataDictionary(), MetaDataKey::TileHintY, m_OutputGridSize[1]); - - } +} template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::GenerateInputRequestedRegion() - { +TensorflowMultisourceModelFilter<TInputImage, 
TOutputImage>::GenerateInputRequestedRegion() +{ Superclass::GenerateInputRequestedRegion(); // Output requested region @@ -320,35 +325,37 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> EnlargeToAlignedRegion(requestedRegion); // For each image, get the requested region - for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i) - { - ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) ); + for (unsigned int i = 0; i < this->GetNumberOfInputs(); ++i) + { + ImageType * inputImage = static_cast<ImageType *>(Superclass::ProcessObject::GetInput(i)); // Compute the requested region RegionType inRegion; - if (!OutputRegionToInputRegion(requestedRegion, inRegion, inputImage) ) - { + if (!OutputRegionToInputRegion(requestedRegion, inRegion, inputImage)) + { // Image does not overlap requested region: set requested region to null - itkDebugMacro( << "Image #" << i << " :\n" << inRegion << " is outside the requested region"); + otbLogMacro(Debug, << "Image #" << i << " :\n" << inRegion << " is outside the requested region"); inRegion.GetModifiableIndex().Fill(0); inRegion.GetModifiableSize().Fill(0); - } + } // Compute the FOV-scale*FOE radius to pad SizeType toPad(this->GetInputReceptiveFields().at(i)); - for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim) - { - int valToPad = 1 + (this->GetOutputExpressionFields().at(0)[dim] - 1) * m_OutputSpacingScale * this->GetInput(0)->GetSpacing()[dim] / this->GetInput(i)->GetSpacing()[dim] ; + for (unsigned int dim = 0; dim < ImageType::ImageDimension; ++dim) + { + int valToPad = 1 + (this->GetOutputExpressionFields().at(0)[dim] - 1) * m_OutputSpacingScale * + this->GetInput(0)->GetSpacing()[dim] / this->GetInput(i)->GetSpacing()[dim]; if (valToPad > toPad[dim]) - itkExceptionMacro("The input requested region of source #" << i << " is not consistent (dim "<< dim<< ")." << - "Please check RF, EF, SF vs physical spacing of your image!" 
<< - "\nReceptive field: " << this->GetInputReceptiveFields().at(i)[dim] << - "\nExpression field: " << this->GetOutputExpressionFields().at(0)[dim] << - "\nScale factor: " << m_OutputSpacingScale << - "\nReference image spacing: " << this->GetInput(0)->GetSpacing()[dim] << - "\nImage " << i << " spacing: " << this->GetInput(i)->GetSpacing()[dim]); + itkExceptionMacro("The input requested region of source #" + << i << " is not consistent (dim " << dim << ")." + << "Please check RF, EF, SF vs physical spacing of your image!" + << "\nReceptive field: " << this->GetInputReceptiveFields().at(i)[dim] + << "\nExpression field: " << this->GetOutputExpressionFields().at(0)[dim] + << "\nScale factor: " << m_OutputSpacingScale + << "\nReference image spacing: " << this->GetInput(0)->GetSpacing()[dim] << "\nImage " << i + << " spacing: " << this->GetInput(i)->GetSpacing()[dim]); toPad[dim] -= valToPad; - } + } // Pad with radius SmartPad(inRegion, toPad); @@ -359,30 +366,28 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // can be one pixel larger when the input image regions are not physically // aligned. 
if (!m_FullyConvolutional) - { + { inRegion.PadByRadius(1); - } + } inRegion.Crop(inputImage->GetLargestPossibleRegion()); // Update the requested region inputImage->SetRequestedRegion(inRegion); - } // next image - - } + } // next image +} /** * Compute the output image */ template <class TInputImage, class TOutputImage> void -TensorflowMultisourceModelFilter<TInputImage, TOutputImage> -::GenerateData() - { +TensorflowMultisourceModelFilter<TInputImage, TOutputImage>::GenerateData() +{ // Output pointer and requested region typename TOutputImage::Pointer outputPtr = this->GetOutput(); - const RegionType outputReqRegion = outputPtr->GetRequestedRegion(); + const RegionType outputReqRegion = outputPtr->GetRequestedRegion(); // Get the aligned output requested region RegionType outputAlignedReqRegion(outputReqRegion); @@ -393,11 +398,12 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Create input tensors list DictType inputs; + // Populate input tensors - for (unsigned int i = 0 ; i < nInputs ; i++) - { + for (unsigned int i = 0; i < nInputs; i++) + { // Input image pointer - const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i)); + const ImagePointerType inputPtr = const_cast<TInputImage *>(this->GetInput(i)); // Patch size of tensor #i const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i); @@ -406,13 +412,13 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> const RegionType reqRegion = inputPtr->GetRequestedRegion(); if (m_FullyConvolutional) - { + { // Shape of input tensor #i - tensorflow::int64 sz_n = 1; - tensorflow::int64 sz_y = reqRegion.GetSize(1); - tensorflow::int64 sz_x = reqRegion.GetSize(0); - tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); - tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c}); + tensorflow::int64 sz_n = 1; + tensorflow::int64 sz_y = reqRegion.GetSize(1); + tensorflow::int64 sz_x = reqRegion.GetSize(0); + tensorflow::int64 sz_c = 
inputPtr->GetNumberOfComponentsPerPixel(); + tensorflow::TensorShape inputTensorShape({ sz_n, sz_y, sz_x, sz_c }); // Create the input tensor tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); @@ -423,16 +429,16 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Input is the tensor representing the subset of image DictElementType input = { this->GetInputPlaceholders()[i], inputTensor }; inputs.push_back(input); - } + } else - { + { // Preparing patches // Shape of input tensor #i - tensorflow::int64 sz_n = outputReqRegion.GetNumberOfPixels(); - tensorflow::int64 sz_y = inputPatchSize[1]; - tensorflow::int64 sz_x = inputPatchSize[0]; - tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); - tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c}); + tensorflow::int64 sz_n = outputReqRegion.GetNumberOfPixels(); + tensorflow::int64 sz_y = inputPatchSize[1]; + tensorflow::int64 sz_x = inputPatchSize[0]; + tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); + tensorflow::TensorShape inputTensorShape({ sz_n, sz_y, sz_x, sz_c }); // Create the input tensor tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); @@ -440,10 +446,10 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Fill the input tensor. 
// We iterate over points which are located from the index iterator // moving through the output image requested region - unsigned int elemIndex = 0; + unsigned int elemIndex = 0; IndexIteratorType idxIt(outputPtr, outputReqRegion); for (idxIt.GoToBegin(); !idxIt.IsAtEnd(); ++idxIt) - { + { // Get the coordinates of the current output pixel PointType point; outputPtr->TransformIndexToPhysicalPoint(idxIt.GetIndex(), point); @@ -451,17 +457,18 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Sample the i-th input patch centered on the point tf::SampleCenteredPatch<TInputImage>(inputPtr, point, inputPatchSize, inputTensor, elemIndex); elemIndex++; - } + } // Input is the tensor of patches (aka the batch) DictElementType input = { this->GetInputPlaceholders()[i], inputTensor }; inputs.push_back(input); - } // mode is not full convolutional + } // mode is not full convolutional - } // next input tensor + } // next input tensor // Run session + // TODO: see if we print some info about inputs/outputs of the model e.g. m_OutputTensors TensorListType outputs; this->RunSession(inputs, outputs); @@ -472,26 +479,25 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Get output tensors int bandOffset = 0; - for (unsigned int i = 0 ; i < outputs.size() ; i++) - { + for (unsigned int i = 0; i < outputs.size(); i++) + { // The offset (i.e. 
the starting index of the channel for the output tensor) is updated // during this call - // TODO: implement a generic strategy enabling expression field copy in patch-based mode (see tf::CopyTensorToImageRegion) + // TODO: implement a generic strategy enabling expression field copy in patch-based mode (see + // tf::CopyTensorToImageRegion) try - { - tf::CopyTensorToImageRegion<TOutputImage> (outputs[i], - outputAlignedReqRegion, outputPtr, outputReqRegion, bandOffset); - } - catch( itk::ExceptionObject & err ) - { + { + tf::CopyTensorToImageRegion<TOutputImage>( + outputs[i], outputAlignedReqRegion, outputPtr, outputReqRegion, bandOffset); + } + catch (itk::ExceptionObject & err) + { std::stringstream debugMsg = this->GenerateDebugReport(inputs); - itkExceptionMacro("Error occured during tensor to image conversion.\n" - << "Context: " << debugMsg.str() - << "Error:" << err); - } + itkExceptionMacro("Error occurred during tensor to image conversion.\n" + << "Context: " << debugMsg.str() << "Error:" << err); } - - } + } +} } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelLearningBase.h b/include/otbTensorflowMultisourceModelLearningBase.h index ba130453ff73a810e3cb25a15e577bea63c78377..6e01317db89d235e7ddae6740136d28f6470cc59 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.h +++ b/include/otbTensorflowMultisourceModelLearningBase.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -53,37 +53,35 @@ namespace otb * \ingroup OTBTensorflow */ template <class TInputImage> -class ITK_EXPORT TensorflowMultisourceModelLearningBase : -public TensorflowMultisourceModelBase<TInputImage> +class ITK_EXPORT TensorflowMultisourceModelLearningBase : public 
TensorflowMultisourceModelBase<TInputImage> { public: - /** Standard class typedefs. */ - typedef TensorflowMultisourceModelLearningBase Self; - typedef TensorflowMultisourceModelBase<TInputImage> Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef TensorflowMultisourceModelLearningBase Self; + typedef TensorflowMultisourceModelBase<TInputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Run-time type information (and related methods). */ itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase); /** Images typedefs */ - typedef typename Superclass::ImageType ImageType; - typedef typename Superclass::ImagePointerType ImagePointerType; - typedef typename Superclass::RegionType RegionType; - typedef typename Superclass::SizeType SizeType; - typedef typename Superclass::IndexType IndexType; + typedef typename Superclass::ImageType ImageType; + typedef typename Superclass::ImagePointerType ImagePointerType; + typedef typename Superclass::RegionType RegionType; + typedef typename Superclass::SizeType SizeType; + typedef typename Superclass::IndexType IndexType; /* Typedefs for parameters */ - typedef typename Superclass::DictType DictType; - typedef typename Superclass::DictElementType DictElementType; - typedef typename Superclass::StringList StringList; - typedef typename Superclass::SizeListType SizeListType; - typedef typename Superclass::TensorListType TensorListType; + typedef typename Superclass::DictType DictType; + typedef typename Superclass::DictElementType DictElementType; + typedef typename Superclass::StringList StringList; + typedef typename Superclass::SizeListType SizeListType; + typedef typename Superclass::TensorListType TensorListType; /* Typedefs for index */ - typedef typename ImageType::IndexValueType IndexValueType; - typedef std::vector<IndexValueType> IndexListType; + typedef 
typename ImageType::IndexValueType IndexValueType; + typedef std::vector<IndexValueType> IndexListType; // Batch size itkSetMacro(BatchSize, IndexValueType); @@ -98,29 +96,36 @@ public: protected: TensorflowMultisourceModelLearningBase(); - virtual ~TensorflowMultisourceModelLearningBase() {}; + virtual ~TensorflowMultisourceModelLearningBase(){}; - virtual void GenerateOutputInformation(void); + virtual void + GenerateOutputInformation(void) override; - virtual void GenerateInputRequestedRegion(); + virtual void + GenerateInputRequestedRegion(); - virtual void GenerateData(); + virtual void + GenerateData(); - virtual void PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize, const IndexListType & order); + virtual void + PopulateInputTensors(DictType & inputs, + const IndexValueType & sampleStart, + const IndexValueType & batchSize, + const IndexListType & order); - virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize) = 0; + virtual void + ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, const IndexValueType & batchSize) = 0; private: - TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowMultisourceModelLearningBase(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented - unsigned int m_BatchSize; // Batch size - bool m_UseStreaming; // Use streaming on/off + unsigned int m_BatchSize; // Batch size + bool m_UseStreaming; // Use streaming on/off // Read only - IndexValueType m_NumberOfSamples; // Number of samples + IndexValueType m_NumberOfSamples; // Number of samples }; // end class diff --git a/include/otbTensorflowMultisourceModelLearningBase.hxx b/include/otbTensorflowMultisourceModelLearningBase.hxx index 
353478292169c211bc1992549a3b814212f9b59f..bfa26d4dc3789a18879891fc9df9581a03fb4d1a 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.hxx +++ b/include/otbTensorflowMultisourceModelLearningBase.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,39 +18,38 @@ namespace otb { template <class TInputImage> -TensorflowMultisourceModelLearningBase<TInputImage> -::TensorflowMultisourceModelLearningBase(): m_BatchSize(100), -m_UseStreaming(false), m_NumberOfSamples(0) - { - } +TensorflowMultisourceModelLearningBase<TInputImage>::TensorflowMultisourceModelLearningBase() + : m_BatchSize(100) + , m_UseStreaming(false) + , m_NumberOfSamples(0) +{} template <class TInputImage> void -TensorflowMultisourceModelLearningBase<TInputImage> -::GenerateOutputInformation() - { +TensorflowMultisourceModelLearningBase<TInputImage>::GenerateOutputInformation() +{ Superclass::GenerateOutputInformation(); // Set an empty output buffered region ImageType * outputPtr = this->GetOutput(); - RegionType nullRegion; + RegionType nullRegion; nullRegion.GetModifiableSize().Fill(1); outputPtr->SetNumberOfComponentsPerPixel(1); - outputPtr->SetLargestPossibleRegion( nullRegion ); + outputPtr->SetLargestPossibleRegion(nullRegion); // Count the number of samples m_NumberOfSamples = 0; - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { + for (unsigned int i = 0; i < this->GetNumberOfInputs(); i++) + { // Input image pointer - ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); + ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i)); // Make sure input is available - if ( inputPtr.IsNull() ) - { + if (inputPtr.IsNull()) + { itkExceptionMacro(<< "Input " << i << " is 
null!"); - } + } // Update input information inputPtr->UpdateOutputInformation(); @@ -63,67 +62,62 @@ TensorflowMultisourceModelLearningBase<TInputImage> // Check size X if (inputPatchSize[0] != reqRegion.GetSize(0)) - itkExceptionMacro("Patch size for input " << i - << " is " << inputPatchSize - << " but input patches image size is " << reqRegion.GetSize()); + itkExceptionMacro("Patch size for input " << i << " is " << inputPatchSize << " but input patches image size is " + << reqRegion.GetSize()); // Check size Y if (reqRegion.GetSize(1) % inputPatchSize[1] != 0) itkExceptionMacro("Input patches image must have a number of rows which is " - << "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) - << " rows but patch size Y is " << inputPatchSize[1] << " for input " << i); + << "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) + << " rows but patch size Y is " << inputPatchSize[1] << " for input " << i); // Get the batch size const IndexValueType currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1]; // Check the consistency with other inputs if (m_NumberOfSamples == 0) - { + { m_NumberOfSamples = currNumberOfSamples; - } + } else if (m_NumberOfSamples != currNumberOfSamples) - { - itkGenericExceptionMacro("Batch size of input " << (i-1) - << " was " << m_NumberOfSamples - << " but input " << i - << " has a batch size of " << currNumberOfSamples ); - } - } // next input - } + { + itkGenericExceptionMacro("Batch size of input " << (i - 1) << " was " << m_NumberOfSamples << " but input " << i + << " has a batch size of " << currNumberOfSamples); + } + } // next input +} template <class TInputImage> void -TensorflowMultisourceModelLearningBase<TInputImage> -::GenerateInputRequestedRegion() - { +TensorflowMultisourceModelLearningBase<TInputImage>::GenerateInputRequestedRegion() +{ Superclass::GenerateInputRequestedRegion(); // For each image, set the requested region RegionType nullRegion; - for(unsigned int 
i = 0; i < this->GetNumberOfInputs(); ++i) - { - ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) ); + for (unsigned int i = 0; i < this->GetNumberOfInputs(); ++i) + { + ImageType * inputImage = static_cast<ImageType *>(Superclass::ProcessObject::GetInput(i)); // If the streaming is enabled, we don't read the full image if (m_UseStreaming) - { + { inputImage->SetRequestedRegion(nullRegion); - } + } else - { + { inputImage->SetRequestedRegion(inputImage->GetLargestPossibleRegion()); - } - } // next image - } + } + } // next image +} /** * */ template <class TInputImage> void -TensorflowMultisourceModelLearningBase<TInputImage> -::GenerateData() - { +TensorflowMultisourceModelLearningBase<TInputImage>::GenerateData() +{ // Batches loop const IndexValueType nBatches = std::ceil(m_NumberOfSamples / m_BatchSize); @@ -131,15 +125,15 @@ TensorflowMultisourceModelLearningBase<TInputImage> itk::ProgressReporter progress(this, 0, nBatches); - for (IndexValueType batch = 0 ; batch < nBatches ; batch++) - { + for (IndexValueType batch = 0; batch < nBatches; batch++) + { // Feed dict DictType inputs; // Batch start and size const IndexValueType sampleStart = batch * m_BatchSize; - IndexValueType batchSize = m_BatchSize; + IndexValueType batchSize = m_BatchSize; if (rest != 0 && batch == nBatches - 1) { batchSize = rest; @@ -149,40 +143,40 @@ TensorflowMultisourceModelLearningBase<TInputImage> this->ProcessBatch(inputs, sampleStart, batchSize); progress.CompletedPixel(); - } // Next batch - - } + } // Next batch +} template <class TInputImage> void -TensorflowMultisourceModelLearningBase<TInputImage> -::PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize, const IndexListType & order) - { +TensorflowMultisourceModelLearningBase<TInputImage>::PopulateInputTensors(DictType & inputs, + const IndexValueType & sampleStart, + const IndexValueType & batchSize, + const IndexListType & 
order) +{ const bool reorder = order.size(); // Populate input tensors - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { + for (unsigned int i = 0; i < this->GetNumberOfInputs(); i++) + { // Input image pointer - ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); + ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i)); // Patch size of tensor #i const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i); // Create the tensor for the batch - const tensorflow::int64 sz_n = batchSize; - const tensorflow::int64 sz_y = inputPatchSize[1]; - const tensorflow::int64 sz_x = inputPatchSize[0]; - const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); - const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c}); - tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); + const tensorflow::int64 sz_n = batchSize; + const tensorflow::int64 sz_y = inputPatchSize[1]; + const tensorflow::int64 sz_x = inputPatchSize[0]; + const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); + const tensorflow::TensorShape inputTensorShape({ sz_n, sz_y, sz_x, sz_c }); + tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); // Populate the tensor - for (IndexValueType elem = 0 ; elem < batchSize ; elem++) - { + for (IndexValueType elem = 0; elem < batchSize; elem++) + { const tensorflow::uint64 samplePos = sampleStart + elem; - IndexType start; + IndexType start; start[0] = 0; if (reorder) { @@ -190,7 +184,8 @@ TensorflowMultisourceModelLearningBase<TInputImage> } else { - start[1] = samplePos * sz_y;; + start[1] = samplePos * sz_y; + ; } RegionType patchRegion(start, inputPatchSize); if (m_UseStreaming) @@ -198,14 +193,14 @@ TensorflowMultisourceModelLearningBase<TInputImage> // If streaming is enabled, we need to explicitly propagate requested region tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion); } 
- tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem ); - } + tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem); + } // Input #i : the tensor of patches (aka the batch) DictElementType input = { this->GetInputPlaceholders()[i], inputTensor }; inputs.push_back(input); - } // next input tensor - } + } // next input tensor +} } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelTrain.h b/include/otbTensorflowMultisourceModelTrain.h index 8f8983ca0acd20533d9a94936f4d4f1bbec175ea..694f09e0b0ebfdd65305432a602e9f3908c8eadf 100644 --- a/include/otbTensorflowMultisourceModelTrain.h +++ b/include/otbTensorflowMultisourceModelTrain.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -34,11 +34,9 @@ namespace otb * \ingroup OTBTensorflow */ template <class TInputImage> -class ITK_EXPORT TensorflowMultisourceModelTrain : -public TensorflowMultisourceModelLearningBase<TInputImage> +class ITK_EXPORT TensorflowMultisourceModelTrain : public TensorflowMultisourceModelLearningBase<TInputImage> { public: - /** Standard class typedefs. 
*/ typedef TensorflowMultisourceModelTrain Self; typedef TensorflowMultisourceModelLearningBase<TInputImage> Superclass; @@ -52,25 +50,27 @@ public: itkTypeMacro(TensorflowMultisourceModelTrain, TensorflowMultisourceModelLearningBase); /** Superclass typedefs */ - typedef typename Superclass::DictType DictType; - typedef typename Superclass::TensorListType TensorListType; - typedef typename Superclass::IndexValueType IndexValueType; - typedef typename Superclass::IndexListType IndexListType; + typedef typename Superclass::DictType DictType; + typedef typename Superclass::TensorListType TensorListType; + typedef typename Superclass::IndexValueType IndexValueType; + typedef typename Superclass::IndexListType IndexListType; protected: TensorflowMultisourceModelTrain(); - virtual ~TensorflowMultisourceModelTrain() {}; + virtual ~TensorflowMultisourceModelTrain(){}; - virtual void GenerateData(); - virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize); + virtual void + GenerateData(); + virtual void + ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, const IndexValueType & batchSize); private: - TensorflowMultisourceModelTrain(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowMultisourceModelTrain(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented - IndexListType m_RandomIndices; // Reordered indices + IndexListType m_RandomIndices; // Reordered indices }; // end class diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx index e7b68dac9567ebb73065a368cee3cc8e4ef94992..46bc2d7bd22cab4a90a40131436bd428dc77aff9 100644 --- a/include/otbTensorflowMultisourceModelTrain.hxx +++ b/include/otbTensorflowMultisourceModelTrain.hxx @@ -1,7 +1,7 @@ /*========================================================================= - 
Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,37 +18,33 @@ namespace otb { template <class TInputImage> -TensorflowMultisourceModelTrain<TInputImage> -::TensorflowMultisourceModelTrain() - { - } +TensorflowMultisourceModelTrain<TInputImage>::TensorflowMultisourceModelTrain() +{} template <class TInputImage> void -TensorflowMultisourceModelTrain<TInputImage> -::GenerateData() - { +TensorflowMultisourceModelTrain<TInputImage>::GenerateData() +{ // Initial sequence 1...N m_RandomIndices.resize(this->GetNumberOfSamples()); - std::iota (std::begin(m_RandomIndices), std::end(m_RandomIndices), 0); + std::iota(std::begin(m_RandomIndices), std::end(m_RandomIndices), 0); // Shuffle the sequence std::random_device rd; - std::mt19937 g(rd()); + std::mt19937 g(rd()); std::shuffle(m_RandomIndices.begin(), m_RandomIndices.end(), g); // Call the generic method Superclass::GenerateData(); - - } +} template <class TInputImage> void -TensorflowMultisourceModelTrain<TInputImage> -::ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize) - { +TensorflowMultisourceModelTrain<TInputImage>::ProcessBatch(DictType & inputs, + const IndexValueType & sampleStart, + const IndexValueType & batchSize) +{ // Populate input tensors this->PopulateInputTensors(inputs, sampleStart, batchSize, m_RandomIndices); @@ -57,12 +53,11 @@ TensorflowMultisourceModelTrain<TInputImage> this->RunSession(inputs, outputs); // Display outputs tensors - for (auto& o: outputs) + for (auto & o : outputs) { tf::PrintTensorInfos(o); } - - } +} } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelValidate.h b/include/otbTensorflowMultisourceModelValidate.h index f4a95406f06d5b5c1318cf1e5eb8ca16cdfb995e..54691747a7128d625c4a637972da67d02f11e1e1 100644 --- 
a/include/otbTensorflowMultisourceModelValidate.h +++ b/include/otbTensorflowMultisourceModelValidate.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -42,11 +42,9 @@ namespace otb * \ingroup OTBTensorflow */ template <class TInputImage> -class ITK_EXPORT TensorflowMultisourceModelValidate : -public TensorflowMultisourceModelLearningBase<TInputImage> +class ITK_EXPORT TensorflowMultisourceModelValidate : public TensorflowMultisourceModelLearningBase<TInputImage> { public: - /** Standard class typedefs. */ typedef TensorflowMultisourceModelValidate Self; typedef TensorflowMultisourceModelLearningBase<TInputImage> Superclass; @@ -60,20 +58,20 @@ public: itkTypeMacro(TensorflowMultisourceModelValidate, TensorflowMultisourceModelLearningBase); /** Images typedefs */ - typedef typename Superclass::ImageType ImageType; - typedef typename Superclass::ImagePointerType ImagePointerType; - typedef typename Superclass::RegionType RegionType; - typedef typename Superclass::SizeType SizeType; - typedef typename Superclass::IndexType IndexType; - typedef std::vector<ImagePointerType> ImageListType; + typedef typename Superclass::ImageType ImageType; + typedef typename Superclass::ImagePointerType ImagePointerType; + typedef typename Superclass::RegionType RegionType; + typedef typename Superclass::SizeType SizeType; + typedef typename Superclass::IndexType IndexType; + typedef std::vector<ImagePointerType> ImageListType; /* Typedefs for parameters */ - typedef typename Superclass::DictType DictType; - typedef typename Superclass::StringList StringList; - typedef typename Superclass::SizeListType SizeListType; - typedef typename Superclass::TensorListType TensorListType; - typedef typename 
Superclass::IndexValueType IndexValueType; - typedef typename Superclass::IndexListType IndexListType; + typedef typename Superclass::DictType DictType; + typedef typename Superclass::StringList StringList; + typedef typename Superclass::SizeListType SizeListType; + typedef typename Superclass::TensorListType TensorListType; + typedef typename Superclass::IndexValueType IndexValueType; + typedef typename Superclass::IndexListType IndexListType; /* Typedefs for validation */ typedef unsigned long CountValueType; @@ -87,36 +85,43 @@ public: typedef itk::ImageRegionConstIterator<ImageType> IteratorType; /** Set and Get the input references */ - virtual void SetInputReferences(ImageListType input); - ImagePointerType GetInputReference(unsigned int index); + virtual void + SetInputReferences(ImageListType input); + ImagePointerType + GetInputReference(unsigned int index); /** Get the confusion matrix */ - const ConfMatType GetConfusionMatrix(unsigned int target); + const ConfMatType + GetConfusionMatrix(unsigned int target); /** Get the map of classes matrix */ - const MapOfClassesType GetMapOfClasses(unsigned int target); + const MapOfClassesType + GetMapOfClasses(unsigned int target); protected: TensorflowMultisourceModelValidate(); - virtual ~TensorflowMultisourceModelValidate() {}; + virtual ~TensorflowMultisourceModelValidate(){}; - void GenerateOutputInformation(void); - void GenerateData(); - void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize); + void + GenerateOutputInformation(void); + void + GenerateData(); + void + ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, const IndexValueType & batchSize); private: - TensorflowMultisourceModelValidate(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowMultisourceModelValidate(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented - 
ImageListType m_References; // The references images + ImageListType m_References; // The references images // Read only - ConfMatListType m_ConfusionMatrices; // Confusion matrix - MapOfClassesListType m_MapsOfClasses; // Maps of classes + ConfMatListType m_ConfusionMatrices; // Confusion matrix + MapOfClassesListType m_MapsOfClasses; // Maps of classes // Internal - std::vector<MatMapType> m_ConfMatMaps; // Accumulators + std::vector<MatMapType> m_ConfMatMaps; // Accumulators }; // end class diff --git a/include/otbTensorflowMultisourceModelValidate.hxx b/include/otbTensorflowMultisourceModelValidate.hxx index c7673c8fc79abcfe322b02e5ee765aa9e560914e..a929aa884ea97ff3f80af968b275dc1029fb0de4 100644 --- a/include/otbTensorflowMultisourceModelValidate.hxx +++ b/include/otbTensorflowMultisourceModelValidate.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,158 +18,150 @@ namespace otb { template <class TInputImage> -TensorflowMultisourceModelValidate<TInputImage> -::TensorflowMultisourceModelValidate() - { - } +TensorflowMultisourceModelValidate<TInputImage>::TensorflowMultisourceModelValidate() +{} template <class TInputImage> void -TensorflowMultisourceModelValidate<TInputImage> -::GenerateOutputInformation() - { +TensorflowMultisourceModelValidate<TInputImage>::GenerateOutputInformation() +{ Superclass::GenerateOutputInformation(); // Check that there is some reference const unsigned int nbOfRefs = m_References.size(); if (nbOfRefs == 0) - { + { itkExceptionMacro("No reference is set"); - } + } // Check the number of references SizeListType outputPatchSizes = this->GetOutputExpressionFields(); if (nbOfRefs != outputPatchSizes.size()) - { - itkExceptionMacro("There is " << nbOfRefs << " 
references but only " << - outputPatchSizes.size() << " output patch sizes"); - } + { + itkExceptionMacro("There is " << nbOfRefs << " references but only " << outputPatchSizes.size() + << " output patch sizes"); + } // Check reference image infos - for (unsigned int i = 0 ; i < nbOfRefs ; i++) - { - const SizeType outputPatchSize = outputPatchSizes[i]; + for (unsigned int i = 0; i < nbOfRefs; i++) + { + const SizeType outputPatchSize = outputPatchSizes[i]; const RegionType refRegion = m_References[i]->GetLargestPossibleRegion(); if (refRegion.GetSize(0) != outputPatchSize[0]) - { - itkExceptionMacro("Reference image " << i << " width is " << refRegion.GetSize(0) << - " but patch size (x) is " << outputPatchSize[0]); - } + { + itkExceptionMacro("Reference image " << i << " width is " << refRegion.GetSize(0) << " but patch size (x) is " + << outputPatchSize[0]); + } if (refRegion.GetSize(1) != this->GetNumberOfSamples() * outputPatchSize[1]) - { - itkExceptionMacro("Reference image " << i << " height is " << refRegion.GetSize(1) << - " but patch size (y) is " << outputPatchSize[1] << - " which is not consistent with the number of samples (" << this->GetNumberOfSamples() << ")"); - } + { + itkExceptionMacro("Reference image " + << i << " height is " << refRegion.GetSize(1) << " but patch size (y) is " << outputPatchSize[1] + << " which is not consistent with the number of samples (" << this->GetNumberOfSamples() + << ")"); } - - } + } +} /* * Set the references images */ -template<class TInputImage> +template <class TInputImage> void -TensorflowMultisourceModelValidate<TInputImage> -::SetInputReferences(ImageListType input) - { +TensorflowMultisourceModelValidate<TInputImage>::SetInputReferences(ImageListType input) +{ m_References = input; - } +} /* * Retrieve the i-th reference image * An exception is thrown if it doesn't exist. 
*/ -template<class TInputImage> +template <class TInputImage> typename TensorflowMultisourceModelValidate<TInputImage>::ImagePointerType -TensorflowMultisourceModelValidate<TInputImage> -::GetInputReference(unsigned int index) - { +TensorflowMultisourceModelValidate<TInputImage>::GetInputReference(unsigned int index) +{ if (m_References.size <= index || !m_References[index]) - { + { itkExceptionMacro("There is no input reference #" << index); - } + } return m_References[index]; - } +} /** * Perform the validation * The session is ran over the entire set of batches. - * Output is then validated agains the references images, + * Output is then validated against the references images, * and a confusion matrix is built. */ template <class TInputImage> void -TensorflowMultisourceModelValidate<TInputImage> -::GenerateData() - { +TensorflowMultisourceModelValidate<TInputImage>::GenerateData() +{ // Temporary images for outputs m_ConfusionMatrices.clear(); m_MapsOfClasses.clear(); m_ConfMatMaps.clear(); - for (auto const& ref: m_References) - { - (void) ref; + for (auto const & ref : m_References) + { + (void)ref; // New confusion matrix MatMapType mat; m_ConfMatMaps.push_back(mat); - } + } // Run all the batches Superclass::GenerateData(); // Compute confusion matrices - for (unsigned int i = 0 ; i < m_ConfMatMaps.size() ; i++) - { + for (unsigned int i = 0; i < m_ConfMatMaps.size(); i++) + { // Confusion matrix (map) for current target MatMapType mat = m_ConfMatMaps[i]; // List all values MapOfClassesType values; - LabelValueType curVal = 0; - for (auto const& ref: mat) - { + LabelValueType curVal = 0; + for (auto const & ref : mat) + { if (values.count(ref.first) == 0) - { + { values[ref.first] = curVal; curVal++; - } - for (auto const& in: ref.second) + } + for (auto const & in : ref.second) if (values.count(in.first) == 0) - { + { values[in.first] = curVal; curVal++; - } - } + } + } // Build the confusion matrix const LabelValueType nValues = values.size(); - 
ConfMatType matrix(nValues, nValues); + ConfMatType matrix(nValues, nValues); matrix.Fill(0); - for (auto const& ref: mat) - for (auto const& in: ref.second) + for (auto const & ref : mat) + for (auto const & in : ref.second) matrix[values[ref.first]][values[in.first]] = in.second; // Add the confusion matrix m_ConfusionMatrices.push_back(matrix); m_MapsOfClasses.push_back(values); - - } - - } + } +} template <class TInputImage> void -TensorflowMultisourceModelValidate<TInputImage> -::ProcessBatch(DictType & inputs, const IndexValueType & sampleStart, - const IndexValueType & batchSize) - { +TensorflowMultisourceModelValidate<TInputImage>::ProcessBatch(DictType & inputs, + const IndexValueType & sampleStart, + const IndexValueType & batchSize) +{ // Populate input tensors IndexListType empty; this->PopulateInputTensors(inputs, sampleStart, batchSize, empty); @@ -180,16 +172,16 @@ TensorflowMultisourceModelValidate<TInputImage> // Perform the validation if (outputs.size() != m_References.size()) - { - itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " << - "but only " << m_References.size() << " reference(s) set"); - } + { + itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " + << "but only " << m_References.size() << " reference(s) set"); + } SizeListType outputEFSizes = this->GetOutputExpressionFields(); - for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++) - { + for (unsigned int refIdx = 0; refIdx < outputs.size(); refIdx++) + { // Recopy the chunk const SizeType outputFOESize = outputEFSizes[refIdx]; - IndexType cpyStart; + IndexType cpyStart; cpyStart.Fill(0); IndexType refRegStart; refRegStart.Fill(0); @@ -216,31 +208,30 @@ TensorflowMultisourceModelValidate<TInputImage> IteratorType inIt(img, cpyRegion); IteratorType refIt(m_References[refIdx], refRegion); for (inIt.GoToBegin(), refIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++refIt) - { + { const int classIn = 
static_cast<LabelValueType>(inIt.Get()[0]); const int classRef = static_cast<LabelValueType>(refIt.Get()[0]); if (m_ConfMatMaps[refIdx].count(classRef) == 0) - { + { MapType newMap; newMap[classIn] = 1; m_ConfMatMaps[refIdx][classRef] = newMap; - } + } else - { + { if (m_ConfMatMaps[refIdx][classRef].count(classIn) == 0) - { + { m_ConfMatMaps[refIdx][classRef][classIn] = 1; - } + } else - { + { m_ConfMatMaps[refIdx][classRef][classIn]++; - } } } } - - } + } +} /* * Get the confusion matrix @@ -248,17 +239,17 @@ TensorflowMultisourceModelValidate<TInputImage> */ template <class TInputImage> const typename TensorflowMultisourceModelValidate<TInputImage>::ConfMatType -TensorflowMultisourceModelValidate<TInputImage> -::GetConfusionMatrix(unsigned int target) - { +TensorflowMultisourceModelValidate<TInputImage>::GetConfusionMatrix(unsigned int target) +{ if (target >= m_ConfusionMatrices.size()) - { - itkExceptionMacro("Unable to get confusion matrix #" << target << ". " << - "There is only " << m_ConfusionMatrices.size() << " available."); - } + { + itkExceptionMacro("Unable to get confusion matrix #" << target << ". " + << "There is only " << m_ConfusionMatrices.size() + << " available."); + } return m_ConfusionMatrices[target]; - } +} /* * Get the map of classes @@ -266,17 +257,17 @@ TensorflowMultisourceModelValidate<TInputImage> */ template <class TInputImage> const typename TensorflowMultisourceModelValidate<TInputImage>::MapOfClassesType -TensorflowMultisourceModelValidate<TInputImage> -::GetMapOfClasses(unsigned int target) - { +TensorflowMultisourceModelValidate<TInputImage>::GetMapOfClasses(unsigned int target) +{ if (target >= m_MapsOfClasses.size()) - { - itkExceptionMacro("Unable to get confusion matrix #" << target << ". " << - "There is only " << m_MapsOfClasses.size() << " available."); - } + { + itkExceptionMacro("Unable to get confusion matrix #" << target << ". 
" + << "There is only " << m_MapsOfClasses.size() + << " available."); + } return m_MapsOfClasses[target]; - } +} } // end namespace otb diff --git a/include/otbTensorflowSampler.h b/include/otbTensorflowSampler.h index d71eba7af1fc6a9d4e24405e0c087c520d456d9d..4fae38e75245ca417c638105379ec7ff7dddf6dd 100644 --- a/include/otbTensorflowSampler.h +++ b/include/otbTensorflowSampler.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -52,16 +52,14 @@ namespace otb * \ingroup OTBTensorflow */ template <class TInputImage, class TVectorData> -class ITK_EXPORT TensorflowSampler : -public itk::ProcessObject +class ITK_EXPORT TensorflowSampler : public itk::ProcessObject { public: - /** Standard class typedefs. */ - typedef TensorflowSampler Self; - typedef itk::ProcessObject Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef TensorflowSampler Self; + typedef itk::ProcessObject Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Method for creation through the object factory. 
*/ itkNewMacro(Self); @@ -70,33 +68,28 @@ public: itkTypeMacro(TensorflowSampler, itk::ProcessObject); /** Images typedefs */ - typedef TInputImage ImageType; - typedef typename TInputImage::Pointer ImagePointerType; - typedef typename TInputImage::InternalPixelType InternalPixelType; - typedef typename TInputImage::PixelType PixelType; - typedef typename TInputImage::RegionType RegionType; - typedef typename TInputImage::PointType PointType; - typedef typename TInputImage::SizeType SizeType; - typedef typename TInputImage::IndexType IndexType; - typedef typename otb::MultiChannelExtractROI<InternalPixelType, - InternalPixelType> ExtractROIMultiFilterType; - typedef typename ExtractROIMultiFilterType::Pointer - ExtractROIMultiFilterPointerType; - typedef typename std::vector<ImagePointerType> ImagePointerListType; - typedef typename std::vector<SizeType> SizeListType; - typedef typename itk::ImageRegionConstIterator<ImageType> - IteratorType; + typedef TInputImage ImageType; + typedef typename TInputImage::Pointer ImagePointerType; + typedef typename TInputImage::InternalPixelType InternalPixelType; + typedef typename TInputImage::PixelType PixelType; + typedef typename TInputImage::RegionType RegionType; + typedef typename TInputImage::PointType PointType; + typedef typename TInputImage::SizeType SizeType; + typedef typename TInputImage::IndexType IndexType; + typedef typename otb::MultiChannelExtractROI<InternalPixelType, InternalPixelType> ExtractROIMultiFilterType; + typedef typename ExtractROIMultiFilterType::Pointer ExtractROIMultiFilterPointerType; + typedef typename std::vector<ImagePointerType> ImagePointerListType; + typedef typename std::vector<SizeType> SizeListType; + typedef typename itk::ImageRegionConstIterator<ImageType> IteratorType; /** Vector data typedefs */ - typedef TVectorData VectorDataType; - typedef typename VectorDataType::Pointer VectorDataPointer; - typedef typename VectorDataType::DataTreeType DataTreeType; - typedef typename 
itk::PreOrderTreeIterator<DataTreeType> - TreeIteratorType; - typedef typename VectorDataType::DataNodeType DataNodeType; - typedef typename DataNodeType::Pointer DataNodePointer; - typedef typename DataNodeType::PolygonListPointerType - PolygonListPointerType; + typedef TVectorData VectorDataType; + typedef typename VectorDataType::Pointer VectorDataPointer; + typedef typename VectorDataType::DataTreeType DataTreeType; + typedef typename itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType; + typedef typename VectorDataType::DataNodeType DataNodeType; + typedef typename DataNodeType::Pointer DataNodePointer; + typedef typename DataNodeType::PolygonListPointerType PolygonListPointerType; /** Set / get parameters */ itkSetMacro(Field, std::string); @@ -107,15 +100,18 @@ public: itkGetConstMacro(InputVectorData, VectorDataPointer); /** Set / get image */ - virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize, InternalPixelType nodataval); - const ImageType* GetInput(unsigned int index); + virtual void + PushBackInputWithPatchSize(const ImageType * input, SizeType & patchSize, InternalPixelType nodataval); + const ImageType * + GetInput(unsigned int index); /** Set / get no-data related parameters */ itkSetMacro(RejectPatchesWithNodata, bool); itkGetMacro(RejectPatchesWithNodata, bool); /** Do the real work */ - virtual void Update(); + virtual void + Update(); /** Get outputs */ itkGetMacro(OutputPatchImages, ImagePointerListType); @@ -125,18 +121,21 @@ public: protected: TensorflowSampler(); - virtual ~TensorflowSampler() {}; + virtual ~TensorflowSampler(){}; - virtual void ResizeImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples); - virtual void AllocateImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples, unsigned int nbComponents); + virtual void + ResizeImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples); + virtual void + AllocateImage(ImagePointerType 
& image, SizeType & patchSize, unsigned int nbSamples, unsigned int nbComponents); private: - TensorflowSampler(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowSampler(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented - std::string m_Field; - SizeListType m_PatchSizes; - VectorDataPointer m_InputVectorData; + std::string m_Field; + SizeListType m_PatchSizes; + VectorDataPointer m_InputVectorData; // Read only ImagePointerListType m_OutputPatchImages; @@ -146,7 +145,7 @@ private: // No data stuff std::vector<InternalPixelType> m_NoDataValues; - bool m_RejectPatchesWithNodata; + bool m_RejectPatchesWithNodata; }; // end class diff --git a/include/otbTensorflowSampler.hxx b/include/otbTensorflowSampler.hxx index 9611d93453b1040aa23602a6d24f7159969902eb..77558c7ba08c6dc75ce8ced1d389a537150a68f7 100644 --- a/include/otbTensorflowSampler.hxx +++ b/include/otbTensorflowSampler.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,36 +18,35 @@ namespace otb { template <class TInputImage, class TVectorData> -TensorflowSampler<TInputImage, TVectorData> -::TensorflowSampler() - { +TensorflowSampler<TInputImage, TVectorData>::TensorflowSampler() +{ m_NumberOfAcceptedSamples = 0; m_NumberOfRejectedSamples = 0; m_RejectPatchesWithNodata = false; - } +} template <class TInputImage, class TVectorData> void -TensorflowSampler<TInputImage, TVectorData> -::PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize, InternalPixelType nodataval) - { - this->ProcessObject::PushBackInput(const_cast<ImageType*>(input)); +TensorflowSampler<TInputImage, 
TVectorData>::PushBackInputWithPatchSize(const ImageType * input, + SizeType & patchSize, + InternalPixelType nodataval) +{ + this->ProcessObject::PushBackInput(const_cast<ImageType *>(input)); m_PatchSizes.push_back(patchSize); m_NoDataValues.push_back(nodataval); - } +} template <class TInputImage, class TVectorData> -const TInputImage* -TensorflowSampler<TInputImage, TVectorData> -::GetInput(unsigned int index) - { +const TInputImage * +TensorflowSampler<TInputImage, TVectorData>::GetInput(unsigned int index) +{ if (this->GetNumberOfInputs() < 1) { itkExceptionMacro("Input not set"); } - return static_cast<const ImageType*>(this->ProcessObject::GetInput(index)); - } + return static_cast<const ImageType *>(this->ProcessObject::GetInput(index)); +} /** @@ -55,9 +54,10 @@ TensorflowSampler<TInputImage, TVectorData> */ template <class TInputImage, class TVectorData> void -TensorflowSampler<TInputImage, TVectorData> -::ResizeImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples) - { +TensorflowSampler<TInputImage, TVectorData>::ResizeImage(ImagePointerType & image, + SizeType & patchSize, + unsigned int nbSamples) +{ // New image region RegionType region; region.SetSize(0, patchSize[0]); @@ -71,16 +71,18 @@ TensorflowSampler<TInputImage, TVectorData> // Assign image = resizer->GetOutput(); - } +} /** * Allocate an image given a patch size and a number of samples */ template <class TInputImage, class TVectorData> void -TensorflowSampler<TInputImage, TVectorData> -::AllocateImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples, unsigned int nbComponents) - { +TensorflowSampler<TInputImage, TVectorData>::AllocateImage(ImagePointerType & image, + SizeType & patchSize, + unsigned int nbSamples, + unsigned int nbComponents) +{ // Image region RegionType region; region.SetSize(0, patchSize[0]); @@ -91,16 +93,15 @@ TensorflowSampler<TInputImage, TVectorData> image->SetNumberOfComponentsPerPixel(nbComponents); 
image->SetRegions(region); image->Allocate(); - } +} /** * Do the work */ template <class TInputImage, class TVectorData> void -TensorflowSampler<TInputImage, TVectorData> -::Update() - { +TensorflowSampler<TInputImage, TVectorData>::Update() +{ // Check number of inputs if (this->GetNumberOfInputs() != m_PatchSizes.size()) @@ -109,8 +110,8 @@ TensorflowSampler<TInputImage, TVectorData> } // Count points - unsigned int nTotal = 0; - unsigned int geomId = 0; + unsigned int nTotal = 0; + unsigned int geomId = 0; TreeIteratorType itVector(m_InputVectorData->GetDataTree()); itVector.GoToBegin(); while (!itVector.IsAtEnd()) @@ -146,7 +147,7 @@ TensorflowSampler<TInputImage, TVectorData> const unsigned int nbInputs = this->GetNumberOfInputs(); m_OutputPatchImages.clear(); m_OutputPatchImages.reserve(nbInputs); - for (unsigned int i = 0 ; i < nbInputs ; i++) + for (unsigned int i = 0; i < nbInputs; i++) { ImagePointerType newImage; AllocateImage(newImage, m_PatchSizes[i], nTotal, GetInput(i)->GetNumberOfComponentsPerPixel()); @@ -154,13 +155,13 @@ TensorflowSampler<TInputImage, TVectorData> m_OutputPatchImages.push_back(newImage); } - itk::ProgressReporter progess(this, 0, nTotal); + itk::ProgressReporter progress(this, 0, nTotal); // Iterate on the vector data itVector.GoToBegin(); unsigned long count = 0; unsigned long rejected = 0; - IndexType labelIndex; + IndexType labelIndex; labelIndex[0] = 0; PixelType labelPix; labelPix.SetSize(1); @@ -169,13 +170,13 @@ TensorflowSampler<TInputImage, TVectorData> if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder()) { DataNodePointer currentGeometry = itVector.Get(); - PointType point = currentGeometry->GetPoint(); + PointType point = currentGeometry->GetPoint(); // Get the label value labelPix[0] = static_cast<InternalPixelType>(currentGeometry->GetFieldAsInt(m_Field)); bool hasBeenSampled = true; - for (unsigned int i = 0 ; i < nbInputs ; i++) + for (unsigned int i = 0; i < nbInputs; 
i++) { // Get input ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i)); @@ -188,7 +189,7 @@ TensorflowSampler<TInputImage, TVectorData> } // Check if the sampled patch contains a no-data value if (m_RejectPatchesWithNodata && hasBeenSampled) - { + { IndexType outIndex; outIndex[0] = 0; outIndex[1] = count * m_PatchSizes[i][1]; @@ -196,13 +197,13 @@ TensorflowSampler<TInputImage, TVectorData> IteratorType it(m_OutputPatchImages[i], region); for (it.GoToBegin(); !it.IsAtEnd(); ++it) - { + { PixelType pix = it.Get(); - for (unsigned int band = 0 ; band < pix.Size() ; band++) + for (unsigned int band = 0; band < pix.Size(); band++) if (pix[band] == m_NoDataValues[i]) hasBeenSampled = false; - } } + } } // Next input if (hasBeenSampled) { @@ -218,9 +219,8 @@ TensorflowSampler<TInputImage, TVectorData> rejected++; } - // Update progres - progess.CompletedPixel(); - + // Update progress + progress.CompletedPixel(); } ++itVector; @@ -228,7 +228,7 @@ TensorflowSampler<TInputImage, TVectorData> // Resize output images ResizeImage(m_OutputLabelImage, labelPatchSize, count); - for (unsigned int i = 0 ; i < nbInputs ; i++) + for (unsigned int i = 0; i < nbInputs; i++) { ResizeImage(m_OutputPatchImages[i], m_PatchSizes[i], count); } @@ -236,8 +236,7 @@ TensorflowSampler<TInputImage, TVectorData> // Update number of samples produced m_NumberOfAcceptedSamples = count; m_NumberOfRejectedSamples = rejected; - - } +} } // end namespace otb diff --git a/include/otbTensorflowSamplingUtils.cxx b/include/otbTensorflowSamplingUtils.cxx index 5a8b8e3c03225c855479adc07f371b730cec787e..db4d9ea01d718c5957a9080486dcb18b83097995 100644 --- a/include/otbTensorflowSamplingUtils.cxx +++ b/include/otbTensorflowSamplingUtils.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE 
This software is distributed WITHOUT ANY WARRANTY; without even @@ -19,13 +19,15 @@ namespace tf // // Update the distribution of the patch located at the specified location // -template<class TImage, class TDistribution> -bool UpdateDistributionFromPatch(const typename TImage::Pointer inPtr, - typename TImage::PointType point, typename TImage::SizeType patchSize, - TDistribution & dist) +template <class TImage, class TDistribution> +bool +UpdateDistributionFromPatch(const typename TImage::Pointer inPtr, + typename TImage::PointType point, + typename TImage::SizeType patchSize, + TDistribution & dist) { typename TImage::IndexType index; - bool canTransform = inPtr->TransformPhysicalPointToIndex(point, index); + bool canTransform = inPtr->TransformPhysicalPointToIndex(point, index); if (canTransform) { index[0] -= patchSize[0] / 2; @@ -38,7 +40,7 @@ bool UpdateDistributionFromPatch(const typename TImage::Pointer inPtr, // Fill patch PropagateRequestedRegion<TImage>(inPtr, inPatchRegion); - typename itk::ImageRegionConstIterator<TImage> inIt (inPtr, inPatchRegion); + typename itk::ImageRegionConstIterator<TImage> inIt(inPtr, inPatchRegion); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { dist.Update(inIt.Get()); @@ -47,7 +49,6 @@ bool UpdateDistributionFromPatch(const typename TImage::Pointer inPtr, } } return false; - } diff --git a/include/otbTensorflowSamplingUtils.h b/include/otbTensorflowSamplingUtils.h index 93879301cf5341a2d479e8752d6af3c4d9ed6ee5..846b71318e57e2f64b7b5d582c4a85223bd809b3 100644 --- a/include/otbTensorflowSamplingUtils.h +++ b/include/otbTensorflowSamplingUtils.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -20,83 +20,89 @@ namespace otb namespace tf { -template<class 
TImage> +template <class TImage> class Distribution { public: typedef typename TImage::PixelType ValueType; - typedef vnl_vector<float> CountsType; - - Distribution(unsigned int nClasses){ - m_NbOfClasses = nClasses; - m_Dist = CountsType(nClasses, 0); - - } - Distribution(unsigned int nClasses, float fillValue){ - m_NbOfClasses = nClasses; - m_Dist = CountsType(nClasses, fillValue); - - } - Distribution(){ - m_NbOfClasses = 2; - m_Dist = CountsType(m_NbOfClasses, 0); - } - Distribution(const Distribution & other){ - m_Dist = other.Get(); - m_NbOfClasses = m_Dist.size(); - } - ~Distribution(){} - - void Update(const typename TImage::PixelType & pixel) + typedef vnl_vector<float> CountsType; + + explicit Distribution(unsigned int nClasses) + : m_NbOfClasses(nClasses) + , m_Dist(CountsType(nClasses, 0)) + {} + Distribution(unsigned int nClasses, float fillValue) + : m_NbOfClasses(nClasses) + , m_Dist(CountsType(nClasses, fillValue)) + {} + Distribution() + : m_NbOfClasses(2) + , m_Dist(CountsType(m_NbOfClasses, 0)) + {} + Distribution(const Distribution & other) + : m_Dist(other.Get()) + , m_NbOfClasses(m_Dist.size()) + {} + ~Distribution() {} + + void + Update(const typename TImage::PixelType & pixel) { m_Dist[pixel]++; } - void Update(const Distribution & other) + void + Update(const Distribution & other) { const CountsType otherDist = other.Get(); - for (unsigned int c = 0 ; c < m_NbOfClasses ; c++) + for (unsigned int c = 0; c < m_NbOfClasses; c++) m_Dist[c] += otherDist[c]; } - CountsType Get() const + CountsType + Get() const { return m_Dist; } - CountsType GetNormalized() const + CountsType + GetNormalized() const { - const float invNorm = 1.0 / std::sqrt(dot_product(m_Dist, m_Dist)); + const float invNorm = 1.0 / std::sqrt(dot_product(m_Dist, m_Dist)); const CountsType normalizedDist = invNorm * m_Dist; return normalizedDist; } - float Cosinus(const Distribution & other) const + float + Cosinus(const Distribution & other) const { return 
dot_product(other.GetNormalized(), GetNormalized()); } - std::string ToString() + std::string + ToString() { std::stringstream ss; ss << "\n"; - for (unsigned int c = 0 ; c < m_NbOfClasses ; c++) + for (unsigned int c = 0; c < m_NbOfClasses; c++) ss << "\tClass #" << c << " : " << m_Dist[c] << "\n"; return ss.str(); } private: unsigned int m_NbOfClasses; - CountsType m_Dist; + CountsType m_Dist; }; // Update the distribution of the patch located at the specified location -template<class TImage, class TDistribution> -bool UpdateDistributionFromPatch(const typename TImage::Pointer inPtr, - typename TImage::PointType point, typename TImage::SizeType patchSize, - TDistribution & dist); - -} // namesapce tf +template <class TImage, class TDistribution> +bool +UpdateDistributionFromPatch(const typename TImage::Pointer inPtr, + typename TImage::PointType point, + typename TImage::SizeType patchSize, + TDistribution & dist); + +} // namespace tf } // namespace otb #include "otbTensorflowSamplingUtils.cxx" diff --git a/include/otbTensorflowSource.h b/include/otbTensorflowSource.h index f569c720f2143c81d7608072cd51f6eac746076d..9bbeed12fbe07820c1a59d125d0af5343dfd3492 100644 --- a/include/otbTensorflowSource.h +++ b/include/otbTensorflowSource.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -25,49 +25,47 @@ namespace otb { /* - * This is a simple helper to create images concatenation. + * This is a helper for images concatenation. * Images must have the same size. - * This is basically the common input type used in every OTB-TF applications. + * This is the common input type used in every OTB-TF applications. 
*/ -template<class TImage> +template <class TImage> class TensorflowSource { public: /** Typedefs for images */ - typedef TImage FloatVectorImageType; - typedef typename FloatVectorImageType::Pointer FloatVectorImagePointerType; - typedef typename FloatVectorImageType::InternalPixelType InternalPixelType; - typedef otb::Image<InternalPixelType> FloatImageType; - typedef typename FloatImageType::SizeType SizeType; + typedef TImage FloatVectorImageType; + typedef typename FloatVectorImageType::Pointer FloatVectorImagePointerType; + typedef typename FloatVectorImageType::InternalPixelType InternalPixelType; + typedef otb::Image<InternalPixelType> FloatImageType; + typedef typename FloatImageType::SizeType SizeType; /** Typedefs for image concatenation */ - typedef otb::ImageList<FloatImageType> ImageListType; - typedef typename ImageListType::Pointer ImageListPointer; - typedef ImageListToVectorImageFilter<ImageListType, - FloatVectorImageType> ListConcatenerFilterType; - typedef typename ListConcatenerFilterType::Pointer ListConcatenerFilterPointer; - typedef MultiToMonoChannelExtractROI<InternalPixelType, - InternalPixelType> MultiToMonoChannelFilterType; - typedef ObjectList<MultiToMonoChannelFilterType> ExtractROIFilterListType; - typedef typename ExtractROIFilterListType::Pointer ExtractROIFilterListPointer; - typedef otb::MultiChannelExtractROI<InternalPixelType, - InternalPixelType> ExtractFilterType; - typedef otb::ObjectList<FloatVectorImageType> FloatVectorImageListType; + typedef otb::ImageList<FloatImageType> ImageListType; + typedef typename ImageListType::Pointer ImageListPointer; + typedef ImageListToVectorImageFilter<ImageListType, FloatVectorImageType> ListConcatenerFilterType; + typedef typename ListConcatenerFilterType::Pointer ListConcatenerFilterPointer; + typedef MultiToMonoChannelExtractROI<InternalPixelType, InternalPixelType> MultiToMonoChannelFilterType; + typedef ObjectList<MultiToMonoChannelFilterType> ExtractROIFilterListType; + typedef 
typename ExtractROIFilterListType::Pointer ExtractROIFilterListPointer; + typedef otb::MultiChannelExtractROI<InternalPixelType, InternalPixelType> ExtractFilterType; + typedef otb::ObjectList<FloatVectorImageType> FloatVectorImageListType; // Initialize the source - void Set(FloatVectorImageListType * inputList); + void + Set(FloatVectorImageListType * inputList); // Get the source output - FloatVectorImagePointerType Get(); + FloatVectorImagePointerType + Get(); - TensorflowSource(){}; - virtual ~TensorflowSource (){}; + TensorflowSource(); + virtual ~TensorflowSource(){}; private: ListConcatenerFilterPointer m_Concatener; // Mono-images stacker ImageListPointer m_List; // List of mono-images ExtractROIFilterListPointer m_ExtractorList; // Mono-images extractors - }; } // end namespace otb diff --git a/include/otbTensorflowSource.hxx b/include/otbTensorflowSource.hxx index bb3de775aa889b3fd0ffaa7324dc2f03ba5159e8..2e41253c69c70e328e9d6827ee1a5f5971c716cb 100644 --- a/include/otbTensorflowSource.hxx +++ b/include/otbTensorflowSource.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -17,6 +17,13 @@ namespace otb { +// +// Constructor +// +template <class TImage> +TensorflowSource<TImage>::TensorflowSource() +{} + // // Prepare the big stack of images // @@ -25,36 +32,35 @@ void TensorflowSource<TImage>::Set(FloatVectorImageListType * inputList) { // Create one stack for input images list - m_Concatener = ListConcatenerFilterType::New(); - m_List = ImageListType::New(); + m_Concatener = ListConcatenerFilterType::New(); + m_List = ImageListType::New(); m_ExtractorList = ExtractROIFilterListType::New(); // Split each input vector image into image // and generate an mono channel image list 
inputList->GetNthElement(0)->UpdateOutputInformation(); SizeType size = inputList->GetNthElement(0)->GetLargestPossibleRegion().GetSize(); - for( unsigned int i = 0; i < inputList->Size(); i++ ) + for (unsigned int i = 0; i < inputList->Size(); i++) { FloatVectorImagePointerType vectIm = inputList->GetNthElement(i); vectIm->UpdateOutputInformation(); - if( size != vectIm->GetLargestPossibleRegion().GetSize() ) + if (size != vectIm->GetLargestPossibleRegion().GetSize()) { itkGenericExceptionMacro("Input image size number " << i << " mismatch"); } - for( unsigned int j = 0; j < vectIm->GetNumberOfComponentsPerPixel(); j++) + for (unsigned int j = 0; j < vectIm->GetNumberOfComponentsPerPixel(); j++) { typename MultiToMonoChannelFilterType::Pointer extractor = MultiToMonoChannelFilterType::New(); - extractor->SetInput( vectIm ); - extractor->SetChannel( j+1 ); + extractor->SetInput(vectIm); + extractor->SetChannel(j + 1); extractor->UpdateOutputInformation(); - m_ExtractorList->PushBack( extractor ); - m_List->PushBack( extractor->GetOutput() ); + m_ExtractorList->PushBack(extractor); + m_List->PushBack(extractor->GetOutput()); } } - m_Concatener->SetInput( m_List ); + m_Concatener->SetInput(m_List); m_Concatener->UpdateOutputInformation(); - } // diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h index eee73f875f47c7190fde3e3367b0af0381b51ddb..fa985d007a1040bc5325d3831724c4108316da5e 100644 --- a/include/otbTensorflowStreamerFilter.h +++ b/include/otbTensorflowStreamerFilter.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -26,12 +26,10 @@ namespace otb * \ingroup OTBTensorflow */ template <class TInputImage, class TOutputImage> -class ITK_EXPORT 
TensorflowStreamerFilter : -public itk::ImageToImageFilter<TInputImage, TOutputImage> +class ITK_EXPORT TensorflowStreamerFilter : public itk::ImageToImageFilter<TInputImage, TOutputImage> { public: - /** Standard class typedefs. */ typedef TensorflowStreamerFilter Self; typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass; @@ -51,24 +49,31 @@ public: typedef typename ImageType::SizeType SizeType; typedef typename Superclass::InputImageRegionType RegionType; - typedef TOutputImage OutputImageType; + typedef TOutputImage OutputImageType; itkSetMacro(OutputGridSize, SizeType); itkGetMacro(OutputGridSize, SizeType); protected: TensorflowStreamerFilter(); - virtual ~TensorflowStreamerFilter() {}; + virtual ~TensorflowStreamerFilter(){}; - virtual void UpdateOutputData(itk::DataObject *output){(void) output; this->GenerateData();} + virtual void + UpdateOutputData(itk::DataObject * output) + { + (void)output; + this->GenerateData(); + } - virtual void GenerateData(); + virtual void + GenerateData(); private: - TensorflowStreamerFilter(const Self&); //purposely not implemented - void operator=(const Self&); //purposely not implemented + TensorflowStreamerFilter(const Self &); // purposely not implemented + void + operator=(const Self &); // purposely not implemented - SizeType m_OutputGridSize; // Output grid size + SizeType m_OutputGridSize; // Output grid size }; // end class diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx index 5e5622be542b4849f016e8d2e8494974c75fd6da..3aa1afca538841126700484f34f9261731b75d1e 100644 --- a/include/otbTensorflowStreamerFilter.hxx +++ b/include/otbTensorflowStreamerFilter.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; 
without even @@ -19,30 +19,28 @@ namespace otb { template <class TInputImage, class TOutputImage> -TensorflowStreamerFilter<TInputImage, TOutputImage> -::TensorflowStreamerFilter() - { +TensorflowStreamerFilter<TInputImage, TOutputImage>::TensorflowStreamerFilter() +{ m_OutputGridSize.Fill(1); - } +} /** * Compute the output image */ template <class TInputImage, class TOutputImage> void -TensorflowStreamerFilter<TInputImage, TOutputImage> -::GenerateData() - { +TensorflowStreamerFilter<TInputImage, TOutputImage>::GenerateData() +{ // Output pointer and requested region OutputImageType * outputPtr = this->GetOutput(); - const RegionType outputReqRegion = outputPtr->GetRequestedRegion(); + const RegionType outputReqRegion = outputPtr->GetRequestedRegion(); outputPtr->SetBufferedRegion(outputReqRegion); outputPtr->Allocate(); // Compute the aligned region RegionType region; - for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) - { + for (unsigned int dim = 0; dim < OutputImageType::ImageDimension; ++dim) + { // Get corners IndexValueType lower = outputReqRegion.GetIndex(dim); IndexValueType upper = lower + outputReqRegion.GetSize(dim); @@ -54,35 +52,34 @@ TensorflowStreamerFilter<TInputImage, TOutputImage> // Move corners to aligned positions lower -= deltaLo; if (deltaUp > 0) - { + { upper += m_OutputGridSize[dim] - deltaUp; - } + } // Update region region.SetIndex(dim, lower); region.SetSize(dim, upper - lower); - - } + } // Compute the number of subregions to process const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0]; const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1]; // Progress - itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY); + itk::ProgressReporter progress(this, 0, nbTilesX * nbTilesY); // For each tile, propagate the input region and recopy the output - ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) ); + ImageType * inputImage = 
static_cast<ImageType *>(Superclass::ProcessObject::GetInput(0)); unsigned int tx, ty; - RegionType subRegion; + RegionType subRegion; subRegion.SetSize(m_OutputGridSize); for (ty = 0; ty < nbTilesY; ty++) { - subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1)); + subRegion.SetIndex(1, ty * m_OutputGridSize[1] + region.GetIndex(1)); for (tx = 0; tx < nbTilesX; tx++) { // Update the input subregion - subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0)); + subRegion.SetIndex(0, tx * m_OutputGridSize[0] + region.GetIndex(0)); // The actual region to copy RegionType cpyRegion(subRegion); @@ -94,12 +91,12 @@ TensorflowStreamerFilter<TInputImage, TOutputImage> inputImage->UpdateOutputData(); // Copy the subregion to output - itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion ); + itk::ImageAlgorithm::Copy(inputImage, outputPtr, cpyRegion, cpyRegion); progress.CompletedPixel(); } } - } +} } // end namespace otb diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py index cbb72bb941f4fbe4c06af9e3baa5883ca3646887..117203bafd89bcbfaa272952323434dac4046a8b 100755 --- a/python/ckpt2savedmodel.py +++ b/python/ckpt2savedmodel.py @@ -2,8 +2,8 @@ # -*- coding: utf-8 -*- # ========================================================================== # -# Copyright 2018-2019 Remi Cresson (IRSTEA) -# Copyright 2020 Remi Cresson (INRAE) +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2021 INRAE # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,23 +18,37 @@ # limitations under the License. # # ==========================================================================*/ +""" +This application converts a checkpoint into a SavedModel, that can be used in +TensorflowModelTrain or TensorflowModelServe OTB applications. 
+This is intended to work mostly with tf.v1 models, since the models in tf.v2 +can be more conveniently exported as SavedModel (see how to build a model with +keras in Tensorflow 2). +""" import argparse from tricks import ckpt_to_savedmodel -# Parser -parser = argparse.ArgumentParser() -parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True) -parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+') -parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, - nargs='+') -parser.add_argument("--model", help="Output directory for SavedModel", required=True) -parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') -parser.set_defaults(clear_devices=False) -params = parser.parse_args() -if __name__ == "__main__": +def main(): + """ + Main function + """ + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True) + parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+') + parser.add_argument("--outputs", help="Outputs names (e.g. 
[\"prediction:0\", \"features:0\"])", required=True, + nargs='+') + parser.add_argument("--model", help="Output directory for SavedModel", required=True) + parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') + parser.set_defaults(clear_devices=False) + params = parser.parse_args() + ckpt_to_savedmodel(ckpt_path=params.ckpt, inputs=params.inputs, outputs=params.outputs, savedmodel_path=params.model, clear_devices=params.clear_devices) + + +if __name__ == "__main__": + main() diff --git a/python/create_savedmodel_ienco-m3_patchbased.py b/python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py similarity index 99% rename from python/create_savedmodel_ienco-m3_patchbased.py rename to python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py index fdb772278bb0ffd9db4844be70a8eae7ef6ff8c6..2a3ad56f2e8ad0aea07ae34766944e43e580d6b0 100755 --- a/python/create_savedmodel_ienco-m3_patchbased.py +++ b/python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# ========================================================================== +# ========================================================================= # # Copyright 2018-2019 Remi Cresson, Dino Ienco (IRSTEA) # Copyright 2020-2021 Remi Cresson, Dino Ienco (INRAE) @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# ==========================================================================*/ +# ========================================================================= # Reference: # @@ -26,12 +26,12 @@ # Satellite Data Fusion. IEEE Journal of Selected Topics in Applied Earth # Observations and Remote Sensing, 11(12), 4939-4949. 
+import argparse from tricks import create_savedmodel import tensorflow.compat.v1 as tf import tensorflow.compat.v1.nn.rnn_cell as rnn -tf.disable_v2_behavior() -import argparse +tf.disable_v2_behavior() parser = argparse.ArgumentParser() parser.add_argument("--nunits", type=int, default=1024, help="number of units") @@ -63,7 +63,7 @@ def RnnAttention(x, nunits, nlayer, n_dims, n_timetamps, is_training_ph): cell = rnn.GRUCell(nunits) cells.append(cell) cell = tf.compat.v1.contrib.rnn.MultiRNNCell(cells) - # SIGNLE LAYER: single GRUCell, nunits hidden units each + # SINGLE LAYER: single GRUCell, nunits hidden units each else: cell = rnn.GRUCell(nunits) outputs, _ = tf.compat.v1.nn.static_rnn(cell, x, dtype="float32") diff --git a/python/create_savedmodel_maggiori17_fullyconv.py b/python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py similarity index 95% rename from python/create_savedmodel_maggiori17_fullyconv.py rename to python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py index 32843e764c22249c36a8089af132e453e42114e9..7c2bed5c40c537a55d2f39e904faa6f3c66e506f 100755 --- a/python/create_savedmodel_maggiori17_fullyconv.py +++ b/python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -#========================================================================== +# ========================================================================= # # Copyright 2018-2019 Remi Cresson (IRSTEA) # Copyright 2020-2021 Remi Cresson (INRAE) @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -#==========================================================================*/ +# ========================================================================= # Reference: # @@ -79,7 +79,7 @@ with tf.compat.v1.Graph().as_default(): activation=tf.nn.crelu) # Deconv = conv on the padded/strided input, that is an (5+1)*4 - deconv1 = tf.compat.v1.layers.conv2d_transpose(inputs=conv4, filters=1, strides=(4,4), kernel_size=[8, 8], + deconv1 = tf.compat.v1.layers.conv2d_transpose(inputs=conv4, filters=1, strides=(4, 4), kernel_size=[8, 8], padding="valid", activation=tf.nn.sigmoid) n = tf.shape(deconv1)[0] diff --git a/python/create_savedmodel_pxs_fcn.py b/python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py similarity index 100% rename from python/create_savedmodel_pxs_fcn.py rename to python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py diff --git a/python/create_savedmodel_simple_cnn.py b/python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py similarity index 100% rename from python/create_savedmodel_simple_cnn.py rename to python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py diff --git a/python/create_savedmodel_simple_fcn.py b/python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py similarity index 100% rename from python/create_savedmodel_simple_fcn.py rename to python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py diff --git a/python/otbtf.py b/python/otbtf.py index a925d3ea07b3a0d54bca439fd011be5ccb1f62a0..a23d5237b701b9ca245fa9b850162ca2ec52e064 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- # ========================================================================== # -# Copyright 2018-2019 Remi Cresson (IRSTEA) -# Copyright 2020 Remi Cresson (INRAE) +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2021 INRAE # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,57 +17,57 @@ # limitations under the License. # # ==========================================================================*/ +""" +Contains stuff to help working with TensorFlow and geospatial data in the +OTBTF framework. +""" import threading import multiprocessing import time +import logging +from abc import ABC, abstractmethod import numpy as np import tensorflow as tf import gdal -import logging -from abc import ABC, abstractmethod -""" -------------------------------------------------------- Helpers -------------------------------------------------------- -""" +# ----------------------------------------------------- Helpers -------------------------------------------------------- def gdal_open(filename): """ Open a GDAL raster :param filename: raster file - :return: a GDAL ds instance + :return: a GDAL dataset instance """ - ds = gdal.Open(filename) - if ds is None: + gdal_ds = gdal.Open(filename) + if gdal_ds is None: raise Exception("Unable to open file {}".format(filename)) - return ds + return gdal_ds -def read_as_np_arr(ds, as_patches=True): +def read_as_np_arr(gdal_ds, as_patches=True): """ Read a GDAL raster as numpy array - :param ds: GDAL ds instance + :param gdal_ds: a GDAL dataset instance :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). 
If False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - buffer = ds.ReadAsArray() - szx = ds.RasterXSize + buffer = gdal_ds.ReadAsArray() + size_x = gdal_ds.RasterXSize if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: - n = 1 - szy = ds.RasterYSize + n_elems = 1 + size_y = gdal_ds.RasterYSize else: - n = int(ds.RasterYSize / szx) - szy = szx - return np.float32(buffer.reshape((n, szy, szx, ds.RasterCount))) + n_elems = int(gdal_ds.RasterYSize / size_x) + size_y = size_x + return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) -""" ----------------------------------------------------- Buffer class ------------------------------------------------------ -""" +# -------------------------------------------------- Buffer class ------------------------------------------------------ class Buffer: @@ -80,19 +80,27 @@ class Buffer: self.container = [] def size(self): + """ + Returns the buffer size + """ return len(self.container) - def add(self, x): - self.container.append(x) - assert (self.size() <= self.max_length) + def add(self, new_element): + """ + Add an element in the buffer + :param new_element: new element to add + """ + self.container.append(new_element) + assert self.size() <= self.max_length def is_complete(self): + """ + Return True if the buffer is at full capacity + """ return self.size() == self.max_length -""" ------------------------------------------------- PatchesReaderBase class ----------------------------------------------- -""" +# ---------------------------------------------- PatchesReaderBase class ----------------------------------------------- class PatchesReaderBase(ABC): @@ -106,7 +114,6 @@ class PatchesReaderBase(ABC): Return one sample. :return One sample instance, whatever the sample structure is (dict, numpy array, ...) 
""" - pass @abstractmethod def get_stats(self) -> dict: @@ -129,7 +136,6 @@ class PatchesReaderBase(ABC): "std": np.array([...])}, } """ - pass @abstractmethod def get_size(self): @@ -137,12 +143,9 @@ class PatchesReaderBase(ABC): Returns the total number of samples :return: number of samples (int) """ - pass -""" ------------------------------------------------ PatchesImagesReader class ---------------------------------------------- -""" +# --------------------------------------------- PatchesImagesReader class ---------------------------------------------- class PatchesImagesReader(PatchesReaderBase): @@ -174,66 +177,64 @@ class PatchesImagesReader(PatchesReaderBase): :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. """ - assert (len(filenames_dict.values()) > 0) + assert len(filenames_dict.values()) > 0 - # ds dict - self.ds = dict() - for src_key, src_filenames in filenames_dict.items(): - self.ds[src_key] = [] - for src_filename in src_filenames: - self.ds[src_key].append(gdal_open(src_filename)) + # gdal_ds dict + self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} - if len(set([len(ds_list) for ds_list in self.ds.values()])) != 1: + # check number of patches in each sources + if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1: raise Exception("Each source must have the same number of patches images") # streaming on/off self.use_streaming = use_streaming - # ds check - nb_of_patches = {key: 0 for key in self.ds} + # gdal_ds check + nb_of_patches = {key: 0 for key in self.gdal_ds} self.nb_of_channels = dict() - for src_key, ds_list in self.ds.items(): - for ds in ds_list: - nb_of_patches[src_key] += self._get_nb_of_patches(ds) + for src_key, ds_list in self.gdal_ds.items(): + for gdal_ds in ds_list: + nb_of_patches[src_key] += self._get_nb_of_patches(gdal_ds) if src_key not in self.nb_of_channels: - self.nb_of_channels[src_key] = 
ds.RasterCount + self.nb_of_channels[src_key] = gdal_ds.RasterCount else: - if self.nb_of_channels[src_key] != ds.RasterCount: + if self.nb_of_channels[src_key] != gdal_ds.RasterCount: raise Exception("All patches images from one source must have the same number of channels!" "Error happened for source: {}".format(src_key)) if len(set(nb_of_patches.values())) != 1: raise Exception("Sources must have the same number of patches! Number of patches: {}".format(nb_of_patches)) - # ds sizes - src_key_0 = list(self.ds)[0] # first key - self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.ds[src_key_0]] + # gdal_ds sizes + src_key_0 = list(self.gdal_ds)[0] # first key + self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.gdal_ds[src_key_0]] self.size = sum(self.ds_sizes) # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - patches_list = {src_key: [read_as_np_arr(ds) for ds in self.ds[src_key]] for src_key in self.ds} - self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=-1) for src_key in self.ds} + patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} + self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} def _get_ds_and_offset_from_index(self, index): offset = index - for i, ds_size in enumerate(self.ds_sizes): + idx = None + for idx, ds_size in enumerate(self.ds_sizes): if offset < ds_size: break offset -= ds_size - return i, offset + return idx, offset @staticmethod - def _get_nb_of_patches(ds): - return int(ds.RasterYSize / ds.RasterXSize) + def _get_nb_of_patches(gdal_ds): + return int(gdal_ds.RasterYSize / gdal_ds.RasterXSize) @staticmethod - def _read_extract_as_np_arr(ds, offset): - assert (ds is not None) - psz = ds.RasterXSize + def _read_extract_as_np_arr(gdal_ds, offset): + assert gdal_ds is not None + psz = gdal_ds.RasterXSize yoff = int(offset * psz) - assert (yoff + psz 
<= ds.RasterYSize) - buffer = ds.ReadAsArray(0, yoff, psz, psz) + assert yoff + psz <= gdal_ds.RasterYSize + buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) return np.float32(buffer) @@ -248,14 +249,14 @@ class PatchesImagesReader(PatchesReaderBase): ... "src_key_M": np.array((psz_y_M, psz_x_M, nb_ch_M))} """ - assert (0 <= index) - assert (index < self.size) + assert index >= 0 + assert index < self.size if not self.use_streaming: - res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.ds} + res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} else: i, offset = self._get_ds_and_offset_from_index(index) - res = {src_key: self._read_extract_as_np_arr(self.ds[src_key][i], offset) for src_key in self.ds} + res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} return res @@ -278,7 +279,7 @@ class PatchesImagesReader(PatchesReaderBase): axis = (0, 1) # (row, col) def _filled(value): - return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.ds} + return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.gdal_ds} _maxs = _filled(0.0) _mins = _filled(float("inf")) @@ -298,17 +299,15 @@ class PatchesImagesReader(PatchesReaderBase): "max": _maxs[src_key], "mean": rsize * _sums[src_key], "std": np.sqrt(rsize * _sqsums[src_key] - np.square(rsize * _sums[src_key])) - } for src_key in self.ds} - logging.info("Stats: {}".format(stats)) + } for src_key in self.gdal_ds} + logging.info("Stats: {}", stats) return stats def get_size(self): return self.size -""" -------------------------------------------------- IteratorBase class --------------------------------------------------- -""" +# ----------------------------------------------- IteratorBase class --------------------------------------------------- class IteratorBase(ABC): @@ -320,9 
+319,7 @@ class IteratorBase(ABC): pass -""" ------------------------------------------------- RandomIterator class -------------------------------------------------- -""" +# ---------------------------------------------- RandomIterator class -------------------------------------------------- class RandomIterator(IteratorBase): @@ -352,9 +349,7 @@ class RandomIterator(IteratorBase): np.random.shuffle(self.indices) -""" ---------------------------------------------------- Dataset class ------------------------------------------------------ -""" +# ------------------------------------------------- Dataset class ------------------------------------------------------ class Dataset: @@ -389,10 +384,12 @@ class Dataset: self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) - logging.info("output_types: {}".format(self.output_types)) - logging.info("output_shapes: {}".format(self.output_shapes)) + logging.info("output_types: {}", self.output_types) + logging.info("output_shapes: {}", self.output_shapes) # buffers + if self.size <= buffer_length: + buffer_length = self.size self.miner_buffer = Buffer(buffer_length) self.mining_lock = multiprocessing.Lock() self.consumer_buffer = Buffer(buffer_length) @@ -434,12 +431,12 @@ class Dataset: This function dumps the miner_buffer into the consumer_buffer, and restart the miner_thread """ # Wait for miner to finish his job - t = time.time() + date_t = time.time() self.miner_thread.join() - self.tot_wait += time.time() - t + self.tot_wait += time.time() - date_t # Copy miner_buffer.container --> consumer_buffer.container - self.consumer_buffer.container = [elem for elem in self.miner_buffer.container] + self.consumer_buffer.container = self.miner_buffer.container.copy() # Clear miner_buffer.container self.miner_buffer.container.clear() @@ -454,27 +451,24 @@ class Dataset: """ # Fill the miner_container until it's full while not self.miner_buffer.is_complete(): - try: - index = 
next(self.iterator) - with self.mining_lock: - new_sample = self.patches_reader.get_sample(index=index) - self.miner_buffer.add(new_sample) - except Exception as e: - logging.warning("Error during collecting samples: {}".format(e)) + index = next(self.iterator) + with self.mining_lock: + new_sample = self.patches_reader.get_sample(index=index) + self.miner_buffer.add(new_sample) def _summon_miner_thread(self): """ Create and starts the thread for the data collect """ - t = threading.Thread(target=self._collect) - t.start() - return t + new_thread = threading.Thread(target=self._collect) + new_thread.start() + return new_thread def _generator(self): """ Generator function, used for the tf dataset """ - for elem in range(self.size): + for _ in range(self.size): yield self.read_one_sample() def get_tf_dataset(self, batch_size, drop_remainder=True): @@ -486,7 +480,7 @@ class Dataset: """ if batch_size <= 2 * self.miner_buffer.max_length: logging.warning("Batch size is {} but dataset buffer has {} elements. 
Consider using a larger dataset " - "buffer to avoid I/O bottleneck".format(batch_size, self.miner_buffer.max_length)) + "buffer to avoid I/O bottleneck", batch_size, self.miner_buffer.max_length) return self.tf_dataset.batch(batch_size, drop_remainder=drop_remainder) def get_total_wait_in_seconds(self): @@ -497,9 +491,7 @@ class Dataset: return self.tot_wait -""" -------------------------------------------- DatasetFromPatchesImages class --------------------------------------------- -""" +# ----------------------------------------- DatasetFromPatchesImages class --------------------------------------------- class DatasetFromPatchesImages(Dataset): diff --git a/python/tricks.py b/python/tricks.py index fe4c5deaedb4095275f7161143a23ad968b10392..b31b14c39ddd89498983ea457c16b15f2c5409df 100644 --- a/python/tricks.py +++ b/python/tricks.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- # ========================================================================== # -# Copyright 2018-2019 Remi Cresson (IRSTEA) -# Copyright 2020 Remi Cresson (INRAE) +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2021 INRAE # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,54 +17,42 @@ # limitations under the License. # # ==========================================================================*/ -import gdal -import numpy as np +""" +This module contains a set of python functions to interact with geospatial data +and TensorFlow models. +Starting from OTBTF >= 3.0.0, tricks is only used as a backward compatible stub +for TF 1.X versions. +""" import tensorflow.compat.v1 as tf from deprecated import deprecated - +from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds tf.disable_v2_behavior() +@deprecated(version="3.0.0", reason="Please use otbtf.read_image_as_np() instead") def read_image_as_np(filename, as_patches=False): """ - Read an image as numpy array. 
- @param filename File name of patches image - @param as_patches True if the image must be read as patches - @return 4D numpy array [batch, h, w, c] + Read a patches-image as numpy array. + :param filename: File name of the patches-image + :param as_patches: True if the image must be read as patches + :return 4D numpy array [batch, h, w, c] (batch = 1 when as_patches is False) """ # Open a GDAL dataset - ds = gdal.Open(filename) - if ds is None: - raise Exception("Unable to open file {}".format(filename)) - - # Raster infos - n_bands = ds.RasterCount - szx = ds.RasterXSize - szy = ds.RasterYSize - - # Raster array - myarray = ds.ReadAsArray() - - # Re-order bands (when there is > 1 band) - if (len(myarray.shape) == 3): - axes = (1, 2, 0) - myarray = np.transpose(myarray, axes=axes) + gdal_ds = gdal_open(filename) - if (as_patches): - n = int(szy / szx) - return myarray.reshape((n, szx, szx, n_bands)) - - return myarray.reshape((1, szy, szx, n_bands)) + # Return patches + return read_as_np_arr_from_gdal_ds(gdal_ds=gdal_ds, as_patches=as_patches) +@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build your nets") def create_savedmodel(sess, inputs, outputs, directory): """ - Create a SavedModel - @param sess TF session - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param directory Path for the generated SavedModel + Create a SavedModel from TF 1.X graphs + :param sess: The Tensorflow V1 session + :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) + :param outputs: List of outputs names (e.g. 
["prediction:0", "features:0"]) + :param directory: Path for the generated SavedModel """ print("Create a SavedModel in " + directory) graph = tf.compat.v1.get_default_graph() @@ -72,14 +60,16 @@ def create_savedmodel(sess, inputs, outputs, directory): outputs_names = {o: graph.get_tensor_by_name(o) for o in outputs} tf.compat.v1.saved_model.simple_save(sess, directory, inputs=inputs_names, outputs=outputs_names) + +@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build and save your nets") def ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): """ - Read a Checkpoint and build a SavedModel - @param ckpt_path Path to the checkpoint file (without the ".meta" extension) - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param savedmodel_path Path for the generated SavedModel - @param clear_devices Clear TF devices positionning (True/False) + Read a Checkpoint and build a SavedModel for some TF 1.X graph + :param ckpt_path: Path to the checkpoint file (without the ".meta" extension) + :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) + :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"]) + :param savedmodel_path: Path for the generated SavedModel + :param clear_devices: Clear TensorFlow devices positioning (True/False) """ tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: @@ -90,33 +80,17 @@ def ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_device # Create a SavedModel create_savedmodel(sess, inputs=inputs, outputs=outputs, directory=savedmodel_path) -@deprecated + +@deprecated(version="3.0.0", reason="Please use otbtf.read_image_as_np() instead") def read_samples(filename): - """ + """ Read a patches image. 
@param filename: raster file name """ - return read_image_as_np(filename, as_patches=True) + return read_image_as_np(filename, as_patches=True) -@deprecated -def CreateSavedModel(sess, inputs, outputs, directory): - """ - Create a SavedModel - @param sess TF session - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param directory Path for the generated SavedModel - """ - create_savedmodel(sess, inputs, outputs, directory) -@deprecated -def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): - """ - Read a Checkpoint and build a SavedModel - @param ckpt_path Path to the checkpoint file (without the ".meta" extension) - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param savedmodel_path Path for the generated SavedModel - @param clear_devices Clear TF devices positionning (True/False) - """ - ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices) +# Aliases for backward compatibility +# pylint: disable=invalid-name +CreateSavedModel = create_savedmodel +CheckpointToSavedModel = ckpt_to_savedmodel diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index aee4395220da9820a2f3e15acc6011b26b6528ee..a4260dfe9103806d0ca716022da525da0ee06fca 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,24 @@ otb_module_test() +# Unit tests +set(${otb-module}Tests + otbTensorflowTests.cxx + otbTensorflowCopyUtilsTests.cxx) + +add_executable(otbTensorflowTests ${${otb-module}Tests}) + +target_include_directories(otbTensorflowTests PRIVATE ${tensorflow_include_dir}) +target_link_libraries(otbTensorflowTests ${${otb-module}-Test_LIBRARIES} ${TENSORFLOW_CC_LIB} ${TENSORFLOW_FRAMEWORK_LIB}) +otb_module_target_label(otbTensorflowTests) + +# CopyUtilsTests +otb_add_test(NAME floatValueToTensorTest COMMAND 
otbTensorflowTests floatValueToTensorTest) +otb_add_test(NAME intValueToTensorTest COMMAND otbTensorflowTests intValueToTensorTest) +otb_add_test(NAME boolValueToTensorTest COMMAND otbTensorflowTests boolValueToTensorTest) +otb_add_test(NAME floatVecValueToTensorTest COMMAND otbTensorflowTests floatVecValueToTensorTest) +otb_add_test(NAME intVecValueToTensorTest COMMAND otbTensorflowTests intVecValueToTensorTest) +otb_add_test(NAME boolVecValueToTensorTest COMMAND otbTensorflowTests boolVecValueToTensorTest) + # Directories set(DATADIR ${CMAKE_CURRENT_SOURCE_DIR}/data) set(MODELSDIR ${CMAKE_CURRENT_SOURCE_DIR}/models) @@ -8,11 +27,20 @@ set(MODELSDIR ${CMAKE_CURRENT_SOURCE_DIR}/models) set(IMAGEXS ${DATADIR}/xs_subset.tif) set(IMAGEPAN ${DATADIR}/pan_subset.tif) set(IMAGEPXS ${DATADIR}/pxs_subset.tif) +set(IMAGEPXS2 ${DATADIR}/pxs_subset2.tif) +set(PATCHESA ${DATADIR}/Sentinel-2_B4328_10m_patches_A.jp2) +set(PATCHESB ${DATADIR}/Sentinel-2_B4328_10m_patches_B.jp2) +set(LABELSA ${DATADIR}/Sentinel-2_B4328_10m_labels_A.tif) +set(LABELSB ${DATADIR}/Sentinel-2_B4328_10m_labels_B.tif) +set(PATCHES01 ${DATADIR}/Sentinel-2_B4328_10m_patches_A.jp2) +set(PATCHES11 ${DATADIR}/Sentinel-2_B4328_10m_patches_B.jp2) # Input models set(MODEL1 ${MODELSDIR}/model1) set(MODEL2 ${MODELSDIR}/model2) set(MODEL3 ${MODELSDIR}/model3) +set(MODEL4 ${MODELSDIR}/model4) +set(MODEL5 ${MODELSDIR}/model5) # Output images and baselines set(MODEL1_PB_OUT apTvClTensorflowModelServeCNN16x16PB.tif) @@ -20,6 +48,92 @@ set(MODEL2_PB_OUT apTvClTensorflowModelServeCNN8x8_32x32PB.tif) set(MODEL2_FC_OUT apTvClTensorflowModelServeCNN8x8_32x32FC.tif) set(MODEL3_PB_OUT apTvClTensorflowModelServeFCNN16x16PB.tif) set(MODEL3_FC_OUT apTvClTensorflowModelServeFCNN16x16FC.tif) +set(MODEL4_FC_OUT apTvClTensorflowModelServeFCNN64x64to32x32.tif) +set(MODEL1_SAVED model1_updated) +set(PATCHESIMG_01 patchimg_01.tif) +set(PATCHESIMG_11 patchimg_11.tif) +set(MODEL5_OUT reduce_sum.tif) + +#----------- Patches selection 
---------------- +set(PATCHESPOS_01 ${TEMP}/out_train_32.gpkg) +set(PATCHESPOS_02 ${TEMP}/out_valid_32.gpkg) +set(PATCHESPOS_11 ${TEMP}/out_train_33.gpkg) +set(PATCHESPOS_12 ${TEMP}/out_valid_33.gpkg) +# Even patches +otb_test_application(NAME PatchesSelectionEven + APP PatchesSelection + OPTIONS + -in ${IMAGEPXS2} + -grid.step 32 + -grid.psize 32 + -outtrain ${PATCHESPOS_01} + -outvalid ${PATCHESPOS_02} + ) + +# Odd patches +otb_test_application(NAME PatchesSelectionOdd + APP PatchesSelection + OPTIONS + -in ${IMAGEPXS2} + -grid.step 32 + -grid.psize 33 + -outtrain ${PATCHESPOS_11} + -outvalid ${PATCHESPOS_12} + ) + +#----------- Patches extraction ---------------- +# Even patches +otb_test_application(NAME PatchesExtractionEven + APP PatchesExtraction + OPTIONS + -source1.il ${IMAGEPXS2} + -source1.patchsizex 32 + -source1.patchsizey 32 + -source1.out ${TEMP}/${PATCHESIMG_01} + -vec ${PATCHESPOS_01} + -field id + VALID --compare-image ${EPSILON_6} + ${DATADIR}/${PATCHESIMG_01} + ${TEMP}/${PATCHESIMG_01} + ) + +# Odd patches +otb_test_application(NAME PatchesExtractionOdd + APP PatchesExtraction + OPTIONS + -source1.il ${IMAGEPXS2} + -source1.patchsizex 33 + -source1.patchsizey 33 + -source1.out ${TEMP}/${PATCHESIMG_11} + -vec ${PATCHESPOS_11} + -field id + VALID --compare-image ${EPSILON_6} + ${DATADIR}/${PATCHESIMG_11} + ${TEMP}/${PATCHESIMG_11} + ) + +#----------- Model training : 1-branch CNN (16x16) Patch-Based ---------------- +set(ENV{OTB_LOGGER_LEVEL} DEBUG) +otb_test_application(NAME TensorflowModelTrainCNN16x16PB + APP TensorflowModelTrain + OPTIONS + -training.epochs 10 + -training.source1.il ${PATCHESA} + -training.source1.placeholder "x" + -training.source2.il ${LABELSA} + -training.source2.placeholder "y" + -validation.source1.il ${PATCHESB} + -validation.source2.il ${LABELSB} + -validation.source1.name "x" + -validation.source2.name "prediction" + -training.source1.patchsizex 16 -training.source1.patchsizey 16 + -training.source2.patchsizex 1 
-training.source2.patchsizey 1 + -model.dir ${MODEL1} + -model.saveto ${MODEL1_SAVED} + -training.targetnodes "optimizer" + -validation.mode "class" + ) +set_tests_properties(TensorflowModelTrainCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;$ENV{OTB_LOGGER_LEVEL}") #----------- Model serving : 1-branch CNN (16x16) Patch-Based ---------------- otb_test_application(NAME TensorflowModelServeCNN16x16PB @@ -31,6 +145,7 @@ otb_test_application(NAME TensorflowModelServeCNN16x16PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL1_PB_OUT} ${TEMP}/${MODEL1_PB_OUT}) +set_tests_properties(TensorflowModelServeCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") #----------- Model serving : 2-branch CNN (8x8, 32x32) Patch-Based ---------------- otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32PB @@ -44,7 +159,8 @@ otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL2_PB_OUT} ${TEMP}/${MODEL2_PB_OUT}) -set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32PB PROPERTIES ENVIRONMENT "OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") +set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") + #----------- Model serving : 2-branch CNN (8x8, 32x32) Fully-Conv ---------------- set(ENV{OTB_TF_NSOURCES} 2) @@ -59,7 +175,7 @@ otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32FC VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL2_FC_OUT} ${TEMP}/${MODEL2_FC_OUT}) -set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32FC PROPERTIES ENVIRONMENT "OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") +set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32FC PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") #----------- Model serving : 1-branch FCNN (16x16) Patch-Based ---------------- set(ENV{OTB_TF_NSOURCES} 1) @@ -72,6 +188,8 @@ 
otb_test_application(NAME apTvClTensorflowModelServeFCNN16x16PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL3_PB_OUT} ${TEMP}/${MODEL3_PB_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") + #----------- Model serving : 1-branch FCNN (16x16) Fully-conv ---------------- set(ENV{OTB_TF_NSOURCES} 1) @@ -84,5 +202,60 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN16x16FC VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL3_FC_OUT} ${TEMP}/${MODEL3_FC_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN16x16FC PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") + +#----------- Model serving : 1-branch FCNN (64x64)-->(32x32), Fully-conv ---------------- +set(ENV{OTB_TF_NSOURCES} 1) +otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32 + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPXS2} + -source1.rfieldx 64 -source1.rfieldy 64 -source1.placeholder x + -output.efieldx 32 -output.efieldy 32 -output.names prediction_fcn + -model.dir ${MODEL4} -model.fullyconv on + -out ${TEMP}/${MODEL4_FC_OUT} + VALID --compare-image ${EPSILON_6} + ${DATADIR}/${MODEL4_FC_OUT} + ${TEMP}/${MODEL4_FC_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN64x64to32x32 PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") + +#----------- Test various output tensor shapes ---------------- +# We test the following output shapes on one monochannel image: +# [None] +# [None, 1] +# [None, None, None] +# [None, None, None, 1] +set(ENV{OTB_TF_NSOURCES} 1) +otb_test_application(NAME outputTensorShapesTest1pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest1fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out 
${TEMP}/${MODEL5_OUT} -output.names "tf.reshape" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest2pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_1" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest2fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_1" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest3pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_2" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest3fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_2" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest4pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_3" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest4fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_3" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) diff --git a/test/data/Sentinel-2_B4328_10m_labels_A.tif b/test/data/Sentinel-2_B4328_10m_labels_A.tif new file mode 100644 index 0000000000000000000000000000000000000000..fe3f31127413c33b44dc6522d628111ebc7ddc24 Binary files /dev/null and 
b/test/data/Sentinel-2_B4328_10m_labels_A.tif differ diff --git a/test/data/Sentinel-2_B4328_10m_labels_B.tif b/test/data/Sentinel-2_B4328_10m_labels_B.tif new file mode 100644 index 0000000000000000000000000000000000000000..501f44affb72fa1d4eed1c9f8607ac5284aab87a Binary files /dev/null and b/test/data/Sentinel-2_B4328_10m_labels_B.tif differ diff --git a/test/data/Sentinel-2_B4328_10m_patches_A.jp2 b/test/data/Sentinel-2_B4328_10m_patches_A.jp2 new file mode 100644 index 0000000000000000000000000000000000000000..b3e3d058a2d0c948c1d0e88732e6378945a13a79 Binary files /dev/null and b/test/data/Sentinel-2_B4328_10m_patches_A.jp2 differ diff --git a/test/data/Sentinel-2_B4328_10m_patches_B.jp2 b/test/data/Sentinel-2_B4328_10m_patches_B.jp2 new file mode 100644 index 0000000000000000000000000000000000000000..795af2ecd5908737b62574810b1701ca78eb3bd6 Binary files /dev/null and b/test/data/Sentinel-2_B4328_10m_patches_B.jp2 differ diff --git a/test/data/apTvClTensorflowModelServeCNN16x16PB.tif b/test/data/apTvClTensorflowModelServeCNN16x16PB.tif index 8d936fcbdf8406a5e3321e46c71ac108a7a7cc61..25b1e9a73435a0d9d29683af5a2e6e8bec29eb11 100644 Binary files a/test/data/apTvClTensorflowModelServeCNN16x16PB.tif and b/test/data/apTvClTensorflowModelServeCNN16x16PB.tif differ diff --git a/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif b/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif new file mode 100644 index 0000000000000000000000000000000000000000..1d22a3b97f6f2ded86651054be8259751ee0df11 Binary files /dev/null and b/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif differ diff --git a/test/data/patchimg_01.tif b/test/data/patchimg_01.tif new file mode 100644 index 0000000000000000000000000000000000000000..f4e5262ee17b80ef1bb72cfbcf8d8d899f73a2c2 Binary files /dev/null and b/test/data/patchimg_01.tif differ diff --git a/test/data/patchimg_11.tif b/test/data/patchimg_11.tif new file mode 100644 index 
0000000000000000000000000000000000000000..65b55565f55641cdb6cf496a73e18c919d4191fc Binary files /dev/null and b/test/data/patchimg_11.tif differ diff --git a/test/data/pxs_subset2.tif b/test/data/pxs_subset2.tif new file mode 100644 index 0000000000000000000000000000000000000000..64991c0567ca34217558a502a7071eb34e241c21 Binary files /dev/null and b/test/data/pxs_subset2.tif differ diff --git a/test/models/model1/SavedModel_cnn/saved_model.pb b/test/models/model1/SavedModel_cnn/saved_model.pb deleted file mode 100644 index 7674fb362182be7320d9e754dbaad36b8fe7839f..0000000000000000000000000000000000000000 Binary files a/test/models/model1/SavedModel_cnn/saved_model.pb and /dev/null differ diff --git a/test/models/model1/SavedModel_cnn/variables/variables.data-00000-of-00001 b/test/models/model1/SavedModel_cnn/variables/variables.data-00000-of-00001 deleted file mode 100644 index 437b992435c708ddae6b776a36c7fe89bb38bf8a..0000000000000000000000000000000000000000 Binary files a/test/models/model1/SavedModel_cnn/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/test/models/model1/SavedModel_cnn/variables/variables.index b/test/models/model1/SavedModel_cnn/variables/variables.index deleted file mode 100644 index 759c02f3cfc05ac212fb978e7a1864735763d24c..0000000000000000000000000000000000000000 Binary files a/test/models/model1/SavedModel_cnn/variables/variables.index and /dev/null differ diff --git a/test/models/model1/saved_model.pb b/test/models/model1/saved_model.pb index 7674fb362182be7320d9e754dbaad36b8fe7839f..b22330c86c0b108daf4ea562bb5fda88b97879e8 100644 Binary files a/test/models/model1/saved_model.pb and b/test/models/model1/saved_model.pb differ diff --git a/test/models/model2/saved_model.pb b/test/models/model2/saved_model.pb index 7269539b35274cae39f77778f75fab344527800c..c809fdcef72a36d7a2f2d19a70c7f39fbce0a609 100644 Binary files a/test/models/model2/saved_model.pb and b/test/models/model2/saved_model.pb differ diff --git 
a/test/models/model3/saved_model.pb b/test/models/model3/saved_model.pb index 099d0ce64a10d36cdf29ad14e6a3f4a570e794b5..84e1bba53549a0c811fb9d7ccf46bb4e3e960b45 100644 Binary files a/test/models/model3/saved_model.pb and b/test/models/model3/saved_model.pb differ diff --git a/test/models/model4/saved_model.pb b/test/models/model4/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..d6920be621f6ab486ed49abce1b97332306f5ac9 Binary files /dev/null and b/test/models/model4/saved_model.pb differ diff --git a/test/models/model4/variables/variables.data-00000-of-00001 b/test/models/model4/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..2837c8c97d562b7d107dd81636816db722d7ae9c Binary files /dev/null and b/test/models/model4/variables/variables.data-00000-of-00001 differ diff --git a/test/models/model4/variables/variables.index b/test/models/model4/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..25810ce3e68510179ea8ee45fd3f5c24c45d7bb9 Binary files /dev/null and b/test/models/model4/variables/variables.index differ diff --git a/test/models/model5.py b/test/models/model5.py new file mode 100644 index 0000000000000000000000000000000000000000..cc17d52edce5202befaee4f3fc8f7780fbc06f30 --- /dev/null +++ b/test/models/model5.py @@ -0,0 +1,25 @@ +""" +This test checks that the output tensor shapes are supported. +The input of this model must be a mono channel image. +All 4 different output shapes supported in OTBTF are tested. 
+ +""" +import tensorflow as tf + +# Input +x = tf.keras.Input(shape=[None, None, None], name="x") # [b, h, w, c=1] + +# Create reshaped outputs +shape = tf.shape(x) +b = shape[0] +h = shape[1] +w = shape[2] +y1 = tf.reshape(x, shape=(b*h*w,)) # [b*h*w] +y2 = tf.reshape(x, shape=(b*h*w, 1)) # [b*h*w, 1] +y3 = tf.reshape(x, shape=(b, h, w)) # [b, h, w] +y4 = tf.reshape(x, shape=(b, h, w, 1)) # [b, h, w, 1] + +# Create model +model = tf.keras.Model(inputs={"x": x}, outputs={"y1": y1, "y2": y2, "y3": y3, "y4": y4}) +model.save("model5") + diff --git a/test/models/model5/saved_model.pb b/test/models/model5/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..5c4373bc5152c7d08d7afc1fb719916c365e9a81 Binary files /dev/null and b/test/models/model5/saved_model.pb differ diff --git a/test/models/model5/variables/variables.data-00000-of-00001 b/test/models/model5/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..8025a877597414e2ccfa5ad6df4d353692f8b3fc Binary files /dev/null and b/test/models/model5/variables/variables.data-00000-of-00001 differ diff --git a/test/models/model5/variables/variables.index b/test/models/model5/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..b5dcd195566dfe545927b250de2a228c55c829c4 Binary files /dev/null and b/test/models/model5/variables/variables.index differ diff --git a/test/otbTensorflowCopyUtilsTests.cxx b/test/otbTensorflowCopyUtilsTests.cxx new file mode 100644 index 0000000000000000000000000000000000000000..5b9586460aa1e0cc51a115ea1a6031ebd3cf805e --- /dev/null +++ b/test/otbTensorflowCopyUtilsTests.cxx @@ -0,0 +1,116 @@ +/*========================================================================= + + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + 
PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbTensorflowCopyUtils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "itkMacro.h" + +template<typename T> +int compare(tensorflow::Tensor & t1, tensorflow::Tensor & t2) +{ + std::cout << "Compare " << t1.DebugString() << " and " << t2.DebugString() << std::endl; + if (t1.dims() != t2.dims()) + { + std::cout << "dims() differ!" << std::endl; + return EXIT_FAILURE; + } + if (t1.dtype() != t2.dtype()) + { + std::cout << "dtype() differ!" << std::endl; + return EXIT_FAILURE; + } + if (t1.NumElements() != t2.NumElements()) + { + std::cout << "NumElements() differ!" << std::endl; + return EXIT_FAILURE; + } + for (unsigned int i = 0; i < t1.NumElements(); i++) + if (t1.flat<T>()(i) != t2.flat<T>()(i)) + { + std::cout << "scalar " << i << " differ!" << std::endl; + return EXIT_FAILURE; + } + // Else + std::cout << "Tensors are equals :)" << std::endl; + return EXIT_SUCCESS; +} + +template<typename T> +int genericValueToTensorTest(tensorflow::DataType dt, std::string expr, T value) +{ + tensorflow::Tensor t = otb::tf::ValueToTensor(expr); + tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({1})); + t_ref.scalar<T>()() = value; + + return compare<T>(t, t_ref); +} + +int floatValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest<float>(tensorflow::DT_FLOAT, "0.1234", 0.1234) + && genericValueToTensorTest<float>(tensorflow::DT_FLOAT, "-0.1234", -0.1234) ; +} + +int intValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest<int>(tensorflow::DT_INT32, "1234", 1234) + && genericValueToTensorTest<int>(tensorflow::DT_INT32, "-1234", -1234); +} + +int boolValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return 
genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "true", true) + && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "True", true) + && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "False", false) + && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "false", false); +} + +template<typename T> +int genericVecValueToTensorTest(tensorflow::DataType dt, std::string expr, std::vector<T> values, std::size_t size) +{ + tensorflow::Tensor t = otb::tf::ValueToTensor(expr); + tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({size})); + unsigned int i = 0; + for (auto value: values) + { + t_ref.flat<T>()(i) = value; + i++; + } + + return compare<T>(t, t_ref); +} + +int floatVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest<float>(tensorflow::DT_FLOAT, + "(0.1234, -1,-20,2.56 ,3.5)", + std::vector<float>({0.1234, -1, -20, 2.56 ,3.5}), + 5); +} + +int intVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest<int>(tensorflow::DT_INT32, + "(1234, -1,-20,256 ,35)", + std::vector<int>({1234, -1, -20, 256 ,35}), + 5); +} + +int boolVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest<bool>(tensorflow::DT_BOOL, + "(true, false,True, False)", + std::vector<bool>({true, false, true, false}), + 4); +} + + diff --git a/test/otbTensorflowTests.cxx b/test/otbTensorflowTests.cxx new file mode 100644 index 0000000000000000000000000000000000000000..50e9a91a57b85feecccf44b49b5c3b57f7e69ec3 --- /dev/null +++ b/test/otbTensorflowTests.cxx @@ -0,0 +1,23 @@ +/*========================================================================= + + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. 
+ +=========================================================================*/ +#include "otbTestMain.h" + +void RegisterTests() +{ + REGISTER_TEST(floatValueToTensorTest); + REGISTER_TEST(intValueToTensorTest); + REGISTER_TEST(boolValueToTensorTest); + REGISTER_TEST(floatVecValueToTensorTest); + REGISTER_TEST(intVecValueToTensorTest); + REGISTER_TEST(boolVecValueToTensorTest); +} + diff --git a/test/sr4rs_unittest.py b/test/sr4rs_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb921f8451cc83b3fd7b9e9e90bf61755511eca --- /dev/null +++ b/test/sr4rs_unittest.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import unittest +import os +from pathlib import Path +import gdal +import otbApplication as otb + + +def command_train_succeed(extra_opts=""): + root_dir = os.environ["CI_PROJECT_DIR"] + ckpt_dir = "/tmp/" + + def _input(file_name): + return "{}/sr4rs_data/input/{}".format(root_dir, file_name) + + command = "python {}/sr4rs/code/train.py ".format(root_dir) + command += "--lr_patches " + command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s2.jp2 ") + command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s2.jp2 ") + command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s2.jp2 ") + command += "--hr_patches " + command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s6_cal.jp2 ") + command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s6_cal.jp2 ") + command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s6_cal.jp2 ") + command += "--save_ckpt {} ".format(ckpt_dir) + command += "--depth 4 " + command += "--nresblocks 1 " + command += "--epochs 1 " + command += extra_opts + os.system(command) + file = Path("{}/checkpoint".format(ckpt_dir)) + return file.is_file() + + +class SR4RSv1Test(unittest.TestCase): + + def test_train_nostream(self): + self.assertTrue(command_train_succeed()) + + def test_train_stream(self): + 
self.assertTrue(command_train_succeed(extra_opts="--streaming")) + + def test_inference(self): + root_dir = os.environ["CI_PROJECT_DIR"] + out_img = "/tmp/sr4rs.tif" + baseline = "{}/sr4rs_data/baseline/sr4rs.tif".format(root_dir) + + command = "python {}/sr4rs/code/sr.py ".format(root_dir) + command += "--input {}/sr4rs_data/input/".format(root_dir) + command += "SENTINEL2B_20200929-104857-489_L2A_T31TEJ_C_V2-2_FRE_10m.tif " + command += "--savedmodel {}/sr4rs_sentinel2_bands4328_france2020_savedmodel/ ".format(root_dir) + command += "--output '{}?&box=256:256:512:512'".format(out_img) + os.system(command) + + nbchannels_reconstruct = gdal.Open(out_img).RasterCount + nbchannels_baseline = gdal.Open(baseline).RasterCount + + self.assertTrue(nbchannels_reconstruct == nbchannels_baseline) + + for i in range(1, 1 + nbchannels_baseline): + comp = otb.Registry.CreateApplication('CompareImages') + comp.SetParameterString('ref.in', baseline) + comp.SetParameterInt('ref.channel', i) + comp.SetParameterString('meas.in', out_img) + comp.SetParameterInt('meas.channel', i) + comp.Execute() + mae = comp.GetParameterFloat('mae') + + self.assertTrue(mae < 0.01) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/docker/README.md b/tools/docker/README.md index 8722b52e539303319b50a9de8ca3f11ff721597e..4246b74d33b88ee1daa38c2f3de7ab44911f0917 100644 --- a/tools/docker/README.md +++ b/tools/docker/README.md @@ -74,7 +74,7 @@ docker build --network='host' -t otbtf:oldstable-gpu --build-arg BASE_IMG=nvidia ### Build for another machine and save TF compiled files ```bash -# Use same ubuntu and CUDA version than your target machine, beware of CC optimization and CPU compatibilty +# Use same ubuntu and CUDA version than your target machine, beware of CC optimization and CPU compatibility # (set env variable CC_OPT_FLAGS and avoid "-march=native" if your Docker's CPU is optimized with AVX2/AVX512 but your target CPU isn't) docker build --network='host' -t otbtf:custom 
--build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 \ --build-arg TF=v2.5.0 --build-arg ZIP_TF_BIN=true . @@ -146,7 +146,7 @@ $ mapla ``` ## Common errors -Buid : +Build : `Error response from daemon: manifest for nvidia/cuda:11.0-cudnn8-devel-ubuntu20.04 not found: manifest unknown: manifest unknown` => Image is missing from dockerhub