diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7f713d76a4f72de21ae0dceafb3f985e1e023653..c2f01c4b89654b81394f48adcbb3a617cd0ac6f6 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,20 +1,22 @@ image: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME variables: - OTBTF_VERSION: 3.2.1 + OTBTF_VERSION: 3.3.0 OTB_BUILD: /src/otb/build/OTB/build # Local OTB build directory OTBTF_SRC: /src/otbtf # Local OTBTF source directory OTB_TEST_DIR: $OTB_BUILD/Testing/Temporary # OTB testing directory ARTIFACT_TEST_DIR: $CI_PROJECT_DIR/testing CRC_BOOK_TMP: /tmp/crc_book_tests_tmp + API_TEST_TMP: /tmp/api_tests_tmp + DATADIR: $CI_PROJECT_DIR/test/data DOCKER_BUILDKIT: 1 DOCKER_DRIVER: overlay2 CACHE_IMAGE_BASE: $CI_REGISTRY_IMAGE:otbtf-base CACHE_IMAGE_BUILDER: $CI_REGISTRY_IMAGE:builder BRANCH_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME DEV_IMAGE: $CI_REGISTRY_IMAGE:cpu-basic-dev-testing - CI_REGISTRY_PUBIMG: $CI_REGISTRY_IMAGE/$OTBTF_VERSION - DOCKERHUB_IMAGE_BASE: mdl4eo/otbtf${OTBTF_VERSION} + CI_REGISTRY_PUBIMG: $CI_REGISTRY_IMAGE:$OTBTF_VERSION + DOCKERHUB_IMAGE_BASE: mdl4eo/otbtf:${OTBTF_VERSION} workflow: rules: @@ -140,7 +142,7 @@ crc_book: extends: .applications_test_base script: - mkdir -p $CRC_BOOK_TMP - - TMPDIR=$CRC_BOOK_TMP DATADIR=$CI_PROJECT_DIR/test/data python -m pytest --junitxml=$CI_PROJECT_DIR/report_tutorial.xml $OTBTF_SRC/test/tutorial_unittest.py + - TMPDIR=$CRC_BOOK_TMP python -m pytest --junitxml=$CI_PROJECT_DIR/report_tutorial.xml $OTBTF_SRC/test/tutorial_unittest.py after_script: - cp $CRC_BOOK_TMP/*.* $ARTIFACT_TEST_DIR/ @@ -157,6 +159,14 @@ sr4rs: - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py +otbtf_api: + extends: .applications_test_base + script: + - mkdir $API_TEST_TMP + - TMPDIR=$API_TEST_TMP python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_api.xml $OTBTF_SRC/test/api_unittest.py + after_script: + - cp $API_TEST_TMP/*.* $ARTIFACT_TEST_DIR/ + 
deploy_cpu-dev-testing: stage: Update dev image extends: .docker_build_base @@ -176,10 +186,10 @@ deploy_cpu-dev-testing: deploy_cpu: extends: .ship base variables: - IMAGE_CPU: $CI_REGISTRY_PUBIMG:cpu - IMAGE_CPUDEV: $CI_REGISTRY_PUBIMG:cpu-dev - DOCKERHUB_CPU: $DOCKERHUB_IMAGE_BASE:cpu - DOCKERHUB_CPUDEV: $DOCKERHUB_IMAGE_BASE:cpu-dev + IMAGE_CPU: $CI_REGISTRY_PUBIMG-cpu + IMAGE_CPUDEV: $CI_REGISTRY_PUBIMG-cpu-dev + DOCKERHUB_CPU: $DOCKERHUB_IMAGE_BASE-cpu + DOCKERHUB_CPUDEV: $DOCKERHUB_IMAGE_BASE-cpu-dev script: # cpu - docker build --network='host' --tag $IMAGE_CPU --build-arg BASE_IMG=ubuntu:20.04 --build-arg BZL_CONFIGS="" . @@ -197,12 +207,12 @@ deploy_cpu: deploy_gpu: extends: .ship base variables: - IMAGE_GPU: $CI_REGISTRY_PUBIMG:gpu - IMAGE_GPUDEV: $CI_REGISTRY_PUBIMG:gpu-dev - IMAGE_GPUOPT: $CI_REGISTRY_PUBIMG:gpu-opt - IMAGE_GPUOPTDEV: $CI_REGISTRY_PUBIMG:gpu-opt-dev - DOCKERHUB_GPU: $DOCKERHUB_IMAGE_BASE:gpu - DOCKERHUB_GPUDEV: $DOCKERHUB_IMAGE_BASE:gpu-dev + IMAGE_GPU: $CI_REGISTRY_PUBIMG-gpu + IMAGE_GPUDEV: $CI_REGISTRY_PUBIMG-gpu-dev + IMAGE_GPUOPT: $CI_REGISTRY_PUBIMG-gpu-opt + IMAGE_GPUOPTDEV: $CI_REGISTRY_PUBIMG-gpu-opt-dev + DOCKERHUB_GPU: $DOCKERHUB_IMAGE_BASE-gpu + DOCKERHUB_GPUDEV: $DOCKERHUB_IMAGE_BASE-gpu-dev script: # gpu-opt - docker build --network='host' --tag $IMAGE_GPUOPT --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 . 
diff --git a/Dockerfile b/Dockerfile index da634cea3fb7dbbce68b3660b9900e2d60a0d837..990c55f597646eaa17ea7b23c33bbf82442a8bc3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -85,7 +85,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi ### OTB ARG GUI=false -ARG OTB=7.4.0 +ARG OTB=8.0.1 ARG OTBTESTS=false RUN mkdir /src/otb @@ -149,7 +149,7 @@ COPY --from=builder /src /src # System-wide ENV ENV PATH="/opt/otbtf/bin:$PATH" ENV LD_LIBRARY_PATH="/opt/otbtf/lib:$LD_LIBRARY_PATH" -ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf" +ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/python3/dist-packages:/opt/otbtf/lib/otb/python:/src/otbtf" ENV OTB_APPLICATION_PATH="/opt/otbtf/lib/otb/applications" # Default user, directory and command (bash is the entrypoint when using 'docker create') diff --git a/README.md b/README.md index 343c4863ccad0ea295b9a29b243e3cc518e9c0e0..c2a13a9c918de187be6b46ab147f1f9542729fc0 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # OTBTF: Orfeo ToolBox meets TensorFlow [](https://opensource.org/licenses/Apache-2.0) -[](https://gitlab.irstea.fr/remi.cresson/otbtf/-/commits/develop) +[](https://gitlab.irstea.fr/remi.cresson/otbtf/-/commits/develop) This remote module of the [Orfeo ToolBox](https://www.orfeo-toolbox.org) provides a generic, multi purpose deep learning framework, targeting remote sensing images processing. It contains a set of new process objects that internally invoke [Tensorflow](https://www.tensorflow.org/), and a bunch of user-oriented applications to perform deep learning with real-world remote sensing images. @@ -19,9 +19,12 @@ Applications can be used to build OTB pipelines from Python or C++ APIs. ### Python -`otbtf.py` targets python developers that want to train their own model from python with TensorFlow or Keras. 
+The `otbtf` module targets python developers that want to train their own model from python with TensorFlow or Keras. It provides various classes for datasets and iterators to handle the _patches images_ generated from the `PatchesExtraction` OTB application. -For instance, the `otbtf.Dataset` class provides a method `get_tf_dataset()` which returns a `tf.dataset` that can be used in your favorite TensorFlow pipelines, or convert your patches into TFRecords. +For instance, the `otbtf.DatasetFromPatchesImages` can be instantiated from a set of _patches images_ +and delivers samples as a `tf.data.Dataset` that can be used in your favorite TensorFlow pipelines, or convert your patches into TFRecords. +The `otbtf.TFRecords` enables you to train networks from TFRecords files, which is quite suited for +distributed training. Read more in the [tutorial for keras](otbtf/examples/tensorflow_v2x/fcnn/README.md). `tricks.py` is here for backward compatibility with codes based on OTBTF 1.x and 2.x. @@ -36,6 +39,10 @@ Below are some screen captures of deep learning applications performed at large  + - Sentinel-2 reconstruction with Sentinel-1 VV/VH with the [Decloud software](https://github.com/CNES/decloud), which is based on OTBTF + + + - - Image to image translation (Spot-7 image --> Wikimedia Map using CGAN. So unnecessary but fun!)  @@ -46,9 +53,9 @@ For now you have two options: either use the existing **docker image**, or build ### Docker -Use the latest image from dockerhub: +Use the latest CPU or GPU-enabled image from dockerhub: ``` -docker run mdl4eo/otbtf3.1:cpu-basic otbcli_PatchesExtraction -help +docker run mdl4eo/otbtf:3.3.0-cpu otbcli_PatchesExtraction -help ``` Read more in the [docker use documentation](doc/DOCKERUSE.md). 
diff --git a/RELEASE_NOTES.txt b/RELEASE_NOTES.txt index 6e542468705714e4e8cb7339dd6546bd2f76e2ea..5d78c0d733e87760365b96c3d87dd4e705e2c1a8 100644 --- a/RELEASE_NOTES.txt +++ b/RELEASE_NOTES.txt @@ -1,3 +1,13 @@ +Version 3.3.0 (27 jul 2022) +---------------------------------------------------------------- +* Improves the `dataset` classes (`DatasetFromPatchesImages`, `TFRecords`) to use them easily in keras +* Add the `ModelBase` class, which eases considerably the creation of deep nets for Keras/TF/TensorflowModelServe +* Add an example explaining how to use python classes to build and train models with Keras, and use models in OTB. +* Document the python API (`otbtf.dataset`, `otbtf.tfrecords`, `otbtf.ModelBase`) +* Test the python API in the CI, using the (XS, labels) patches of the Amsterdam dataset from CRC book +* Upgrade OTB to version 8.0.1 +* Upgrade GDAL to version 3.4.2 + Version 3.2.1 (1 jun 2022) ---------------------------------------------------------------- * Changing docker images naming convention (cpu/gpu-basic* --> cpu/gpu*, cpu/gpu* --> cpu/gpu-opt*) + only images without optimizations are pushed on dockerhub diff --git a/app/otbPatchesSelection.cxx b/app/otbPatchesSelection.cxx index 68d76221dbf6b7ffbb828a751420a854d58864a7..3437849b55419d6d9c0e9038e24c94506201ca8d 100644 --- a/app/otbPatchesSelection.cxx +++ b/app/otbPatchesSelection.cxx @@ -35,6 +35,12 @@ #include <random> #include <limits> +namespace otb +{ + +namespace Wrapper +{ + // Functor to retrieve nodata template<class TPixel, class OutputPixel> class IsNoData @@ -62,12 +68,6 @@ private: typename TPixel::ValueType m_NoDataValue; }; -namespace otb -{ - -namespace Wrapper -{ - class PatchesSelection : public Application { public: diff --git a/doc/DOCKERUSE.md b/doc/DOCKERUSE.md index 457e6e01978f39fab39b791de4099f85950640aa..50e505de625df4839f106fd3f659c9c585c11675 100644 --- a/doc/DOCKERUSE.md +++ b/doc/DOCKERUSE.md @@ -2,44 +2,23 @@ ### Available images -Here is the list of OTBTF 
docker images hosted on [dockerhub](https://hub.docker.com/u/mdl4eo). +Here is the list of the latest OTBTF docker images hosted on [dockerhub](https://hub.docker.com/u/mdl4eo). Since OTBTF >= 3.2.1 you can find latest docker images on [gitlab.irstea.fr](https://gitlab.irstea.fr/remi.cresson/otbtf/container_registry). -| Name | Os | TF | OTB | Description | Dev files | Compute capability | -| --------------------------------- | ------------- | ------ | ----- | ---------------------- | --------- | ------------------ | -| **mdl4eo/otbtf1.6:cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | yes | 5.2,6.1,7.0 | -| **mdl4eo/otbtf1.7:cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | yes | 5.2,6.1,7.0 | -| **mdl4eo/otbtf1.7:gpu** | Ubuntu Xenial | r1.14 | 7.0.0 | GPU | yes | 5.2,6.1,7.0 | -| **mdl4eo/otbtf2.0:cpu** | Ubuntu Xenial | r2.1 | 7.1.0 | CPU, no optimization | yes | 5.2,6.1,7.0,7.5 | -| **mdl4eo/otbtf2.0:gpu** | Ubuntu Xenial | r2.1 | 7.1.0 | GPU | yes | 5.2,6.1,7.0,7.5 | -| **mdl4eo/otbtf2.4:cpu-basic** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, no optimization | yes | 5.2,6.1,7.0,7.5 | -| **mdl4eo/otbtf2.4:cpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, few optimizations | no | 5.2,6.1,7.0,7.5 | -| **mdl4eo/otbtf2.4:cpu-mkl** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, Intel MKL, AVX512 | yes | 5.2,6.1,7.0,7.5 | -| **mdl4eo/otbtf2.4:gpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | GPU | yes | 5.2,6.1,7.0,7.5 | -| **mdl4eo/otbtf2.5:cpu-basic** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf2.5:cpu-basic-dev** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf2.5:cpu** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, few optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf2.5:gpu** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf2.5:gpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| 
**mdl4eo/otbtf3.0:cpu-basic** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.0:cpu-basic-dev** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.0:gpu** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.0:gpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:cpu-basic** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:cpu-basic-dev** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:gpu-basic** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:gpu-basic-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:gpu** | Ubuntu Focal | r2.8 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:gpu-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.2.1:cpu** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.2.1:cpu-dev** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.2.1:gpu** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.2.1:gpu-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **gitlab.irstea.fr/remi.cresson/otbtf/container_registry/otbtf3.2.1:gpu-opt** | Ubuntu Focal | r2.8 | 7.4.0 | GPU with opt. | no | 5.2,6.1,7.0,7.5,8.6| -| **gitlab.irstea.fr/remi.cresson/otbtf/container_registry/otbtf3.2.1:gpu-opt-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU with opt. 
(dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| Name | Os | TF | OTB | Description | Dev files | Compute capability | +|------------------------------------------------------------------------------------| ------------- | ------ |-------| ---------------------- | --------- | ------------------ | +| **mdl4eo/otbtf:3.3.0-cpu** | Ubuntu Focal | r2.8 | 8.0.1 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.3.0-cpu-dev** | Ubuntu Focal | r2.8 | 8.0.1 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.3.0-gpu** | Ubuntu Focal | r2.8 | 8.0.1 | GPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.3.0-gpu-dev** | Ubuntu Focal | r2.8 | 8.0.1 | GPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **gitlab.irstea.fr/remi.cresson/otbtf/container_registry/otbtf:3.3.0-gpu-opt** | Ubuntu Focal | r2.8 | 8.0.1 | GPU with opt. | no | 5.2,6.1,7.0,7.5,8.6| +| **gitlab.irstea.fr/remi.cresson/otbtf/container_registry/otbtf:3.3.0-gpu-opt-dev** | Ubuntu Focal | r2.8 | 8.0.1 | GPU with opt. (dev) | yes | 5.2,6.1,7.0,7.5,8.6| + +The list of older releases is available [here](#older-docker-releases). You can also find more interesting OTBTF flavored images at [LaTelescop gitlab registry](https://gitlab.com/latelescop/docker/otbtf/container_registry/). + ### Development ready images Until r2.4, all images are development-ready, and the sources are located in `/work/`. @@ -58,7 +37,7 @@ For instance, suppose you have some data in `/mnt/my_device/` that you want to u The following command shows you how to access the folder from the docker image. ```bash -docker run -v /mnt/my_device/:/data/ -ti mdl4eo/otbtf3.2.1:cpu bash -c "ls /data" +docker run -v /mnt/my_device/:/data/ -ti mdl4eo/otbtf:3.3.0-cpu bash -c "ls /data" ``` Beware of ownership issues! see the last section of this doc. 
@@ -71,13 +50,13 @@ You can then use the OTBTF `gpu` tagged docker images with the **NVIDIA runtime* With Docker version earlier than 19.03 : ```bash -docker run --runtime=nvidia -ti mdl4eo/otbtf3.2.1:gpu bash +docker run --runtime=nvidia -ti mdl4eo/otbtf:3.3.0-gpu bash ``` With Docker version including and after 19.03 : ```bash -docker run --gpus all -ti mdl4eo/otbtf3.2.1:gpu bash +docker run --gpus all -ti mdl4eo/otbtf:3.3.0-gpu bash ``` You can find some details on the **GPU docker image** and some **docker tips and tricks** on [this blog](https://mdl4eo.irstea.fr/2019/10/15/otbtf-docker-image-with-gpu/). @@ -90,7 +69,7 @@ Be careful though, these infos might be a bit outdated... 1. Install [WSL2](https://docs.microsoft.com/en-us/windows/wsl/install-win10#manual-installation-steps) (Windows Subsystem for Linux) 2. Install [docker desktop](https://www.docker.com/products/docker-desktop) 3. Start **docker desktop** and **enable WSL2** from *Settings* > *General* then tick the box *Use the WSL2 based engine* -3. Open a **cmd.exe** or **PowerShell** terminal, and type `docker create --name otbtf-cpu --interactive --tty mdl4eo/otbtf3.2.1:cpu` +3. Open a **cmd.exe** or **PowerShell** terminal, and type `docker create --name otbtf-cpu --interactive --tty mdl4eo/otbtf:3.3.0-cpu` 4. Open **docker desktop**, and check that the docker is running in the **Container/Apps** menu  5. From **docker desktop**, click on the icon highlighted as shown below, and use the bash terminal that should pop up! 
@@ -139,12 +118,12 @@ sudo systemctl {status,enable,disable,start,stop} docker Run a simple command in a one-shot container: ```bash -docker run mdl4eo/otbtf3.2.1:cpu otbcli_PatchesExtraction +docker run mdl4eo/otbtf:3.3.0-cpu otbcli_PatchesExtraction ``` You can also use the image in interactive mode with bash: ```bash -docker run -ti mdl4eo/otbtf3.2.1:cpu bash +docker run -ti mdl4eo/otbtf:3.3.0-cpu bash ``` ### Persistent container @@ -154,7 +133,7 @@ Beware of ownership issues, see the last section of this doc. ```bash docker create --interactive --tty --volume /home/$USER:/home/otbuser/ \ - --name otbtf mdl4eo/otbtf3.2.1:cpu /bin/bash + --name otbtf mdl4eo/otbtf:3.3.0-cpu /bin/bash ``` ### Interactive session @@ -218,7 +197,7 @@ Create a named container (here with your HOME as volume), Docker will automatica ```bash docker create --interactive --tty --volume /home/$USER:/home/otbuser \ - --name otbtf mdl4eo/otbtf3.2.1:cpu /bin/bash + --name otbtf mdl4eo/otbtf:3.3.0-cpu /bin/bash ``` Start a background container process: @@ -255,3 +234,34 @@ id ls -Alh /home/otbuser touch /home/otbuser/test.txt ``` + +# Older docker releases + +Here you can find the list of older releases of OTBTF: + +| Name | Os | TF | OTB | Description | Dev files | Compute capability | +|------------------------------------------------------------------------------------| ------------- | ------ |-------| ---------------------- | --------- | ------------------ | +| **mdl4eo/otbtf:1.6-cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | yes | 5.2,6.1,7.0 | +| **mdl4eo/otbtf:1.7-cpu** | Ubuntu Xenial | r1.14 | 7.0.0 | CPU, no optimization | yes | 5.2,6.1,7.0 | +| **mdl4eo/otbtf:1.7-gpu** | Ubuntu Xenial | r1.14 | 7.0.0 | GPU | yes | 5.2,6.1,7.0 | +| **mdl4eo/otbtf:2.0-cpu** | Ubuntu Xenial | r2.1 | 7.1.0 | CPU, no optimization | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf:2.0-gpu** | Ubuntu Xenial | r2.1 | 7.1.0 | GPU | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf:2.4-cpu** | Ubuntu Focal | 
r2.4.1 | 7.2.0 | CPU, no optimization | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf:2.4-cpu-opt** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, few optimizations | no | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf:2.4-cpu-mkl** | Ubuntu Focal | r2.4.1 | 7.2.0 | CPU, Intel MKL, AVX512 | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf:2.4-gpu** | Ubuntu Focal | r2.4.1 | 7.2.0 | GPU | yes | 5.2,6.1,7.0,7.5 | +| **mdl4eo/otbtf:2.5-cpu** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:2.5-cpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:2.5-cpu-opt** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, few optimizations | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:2.5-gpu-opt** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:2.5-gpu-opt-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.0-cpu** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.0-cpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.0-gpu-opt** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.0-gpu-opt-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.1-cpu** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.1-cpu-dev** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.1-gpu** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.1-gpu-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.1-gpu-opt** | Ubuntu Focal | r2.8 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf:3.1-gpu-opt-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU (dev) | 
yes | 5.2,6.1,7.0,7.5,8.6| \ No newline at end of file diff --git a/doc/EXAMPLES.md b/doc/EXAMPLES.md index 0f5282f3dee96d6f6b8776abe7d6ef2fce8a6203..d48c56a9273794a20dbf498fa2cbab371983c963 100644 --- a/doc/EXAMPLES.md +++ b/doc/EXAMPLES.md @@ -312,5 +312,6 @@ otbcli_TensorflowModelServe \ -source2.il $pan -source2.rfieldx 32 -source2.rfieldy 32 -source2.placeholder "x2" \ -model.dir $modeldir \ -model.fullyconv on \ +-output.names "prediction" \ -out $output_classif ``` diff --git a/otbtf/__init__.py b/otbtf/__init__.py index ac36018a1e7f401c3aaaab9d0809f9541a0c376f..5321e3dba52898934094bdfa5df8efd6cc0d83d7 100644 --- a/otbtf/__init__.py +++ b/otbtf/__init__.py @@ -20,7 +20,9 @@ """ OTBTF python module """ + from otbtf.utils import read_as_np_arr, gdal_open from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ - DatasetFromPatchesImages + DatasetFromPatchesImages from otbtf.tfrecords import TFRecords +from otbtf.model import ModelBase diff --git a/otbtf/dataset.py b/otbtf/dataset.py index 002754811458597db20a34233a83132ece5e7090..b7ca2025ab111031cca69d31c9fb6b8d563cd3fe 100644 --- a/otbtf/dataset.py +++ b/otbtf/dataset.py @@ -273,7 +273,7 @@ class PatchesImagesReader(PatchesReaderBase): "mean": rsize * _sums[src_key], "std": np.sqrt(rsize * _sqsums[src_key] - np.square(rsize * _sums[src_key])) } for src_key in self.gdal_ds} - logging.info("Stats: {}", stats) + logging.info("Stats: %s", stats) return stats def get_size(self): @@ -362,8 +362,8 @@ class Dataset: self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) - logging.info("output_types: {}", self.output_types) - logging.info("output_shapes: {}", self.output_shapes) + logging.info("output_types: %s", self.output_types) + logging.info("output_shapes: %s", self.output_shapes) # buffers if self.size <= buffer_length: @@ -462,17 +462,41 @@ class Dataset: for _ in range(self.size): yield 
self.read_one_sample() - def get_tf_dataset(self, batch_size, drop_remainder=True): + def get_tf_dataset(self, batch_size, drop_remainder=True, preprocessing_fn=None, targets_keys=None): """ Returns a TF dataset, ready to be used with the provided batch size :param batch_size: the batch size :param drop_remainder: drop incomplete batches + :param preprocessing_fn: Optional. A preprocessing function that takes input examples as args and returns the + preprocessed input examples. Typically, examples are composed of model inputs and + targets. Model inputs and model targets must be computed accordingly to (1) what the + model outputs and (2) what training loss needs. For instance, for a classification + problem, the model will likely output the softmax, or activation neurons, for each + class, and the cross entropy loss requires labels in one hot encoding. In this case, + the preprocessing_fn has to transform the labels values (integer ranging from + [0, n_classes]) in one hot encoding (vector of 0 and 1 of length n_classes). The + preprocessing_fn should not implement such things as radiometric transformations from + input to input_preprocessed, because those are performed inside the model itself + (see `otbtf.ModelBase.normalize_inputs()`). + :param targets_keys: Optional. When provided, the dataset returns a tuple of dicts (inputs_dict, target_dict) so + it can be straightforwardly used with keras models objects. :return: The TF dataset """ - if batch_size <= 2 * self.miner_buffer.max_length: - logging.warning("Batch size is {} but dataset buffer has {} elements. Consider using a larger dataset " + if 2 * batch_size >= self.miner_buffer.max_length: + logging.warning("Batch size is %s but dataset buffer has %s elements. 
Consider using a larger dataset " "buffer to avoid I/O bottleneck", batch_size, self.miner_buffer.max_length) - return self.tf_dataset.batch(batch_size, drop_remainder=drop_remainder) + tf_ds = self.tf_dataset.map(preprocessing_fn) if preprocessing_fn else self.tf_dataset + + if targets_keys: + def _split_input_and_target(example): + # Differentiating inputs and outputs for keras + inputs = {key: value for (key, value) in example.items() if key not in targets_keys} + targets = {key: value for (key, value) in example.items() if key in targets_keys} + return inputs, targets + + tf_ds = tf_ds.map(_split_input_and_target) + + return tf_ds.batch(batch_size, drop_remainder=drop_remainder) def get_total_wait_in_seconds(self): """ diff --git a/otbtf/examples/tensorflow_v2x/fcnn/README.md b/otbtf/examples/tensorflow_v2x/fcnn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e6cfce78fe606ff4d525829c302a3e5f8528a900 --- /dev/null +++ b/otbtf/examples/tensorflow_v2x/fcnn/README.md @@ -0,0 +1,64 @@ +This example shows how to train a small fully convolutional model using the +OTBTF python API. In particular, the example shows how a model can be trained +(1) from **patches-images**, or (2) from **TFRecords** files. + +# Files + +- `fcnn_model.py` implements a small fully convolutional U-Net like model, +with the preprocessing and normalization functions +- `train_from_patches-images.py` shows how to train the model from a list of +patches-images +- `train_from_tfrecords.py` shows how to train the model from TFRecords files +- `create_tfrecords.py` shows how to convert patch-images into TFRecords files +- `helper.py` contains a few helping functions + +# Patches-images vs TFRecords based datasets + +TensorFlow datasets are the most practical way to feed a network data during +training steps. +In particular, they are very useful to train models with data parallelism using +multiple workers (i.e. multiple GPU devices). 
+Since OTBTF 3, two kinds of approaches are available to deliver the patches: +- Create TF datasets from **patches-images**: the first approach implemented in +OTBTF, relying on geospatial raster formats supported by GDAL. Patches are simply +stacked in rows. Patches-images are friendly because they can be visualized +like any other image. However this approach is **not very optimized**, since it +generates a lot of I/O and stresses the filesystem when iterating randomly over +patches. +- Create TF datasets from **TFRecords** files. The principle is that a number of +patches are stored in TFRecords files (Google protobuf serialized data). This +approach provides the best performance, since it generates fewer I/Os because +multiple patches are read together. It is the recommended approach +to work on high end gear. It requires an additional step of converting the +patches-images into TFRecords files. + +## Patches-images based datasets + +**Patches-images** are generated from the `PatchesExtraction` application of OTBTF. +They consist of extracted patches stacked in rows into geospatial rasters. +The `otbtf.DatasetFromPatchesImages` provides access to **patches-images** as a +TF dataset. It inherits from the `otbtf.Dataset` class, which can be a base class +to develop other raster based datasets. +The `use_streaming` option can be used to read the patches on-the-fly +on the filesystem. However, this can cause an I/O bottleneck when one training step +is shorter than fetching one batch of data. Typically, this is very common with +small networks trained over large amounts of data using multiple GPUs, making the +filesystem read operation the weak point (and the GPUs wait for the batches +to be ready). The class offers other functionalities, for instance changing the +iterator class with a custom one (can inherit from `otbtf.dataset.IteratorBase`) +which is, by default, an `otbtf.dataset.RandomIterator`. 
This makes it possible to +control how the patches are walked across the multiple patches-images of the +dataset. + +## TFRecords batches datasets + +**TFRecord** based datasets are implemented in the `otbtf.tfrecords` module. +They basically deliver patches from the TFRecords files, which can be created +with the `to_tfrecords()` method of the `otbtf.Dataset` based classes. +Depending on the filesystem characteristics and the computational cost of one +training step, it can be worth tuning the number of samples per TFRecords file. +Another tweak is the shuffling: since one TFRecord file contains multiple patches, +the way TFRecords files are accessed (sometimes, we need them to be randomly +accessed), and the way patches are accessed (within a buffer, of size set with the +`shuffle_buffer_size`), are crucial. + diff --git a/otbtf/examples/tensorflow_v2x/fcnn/create_tfrecords.py b/otbtf/examples/tensorflow_v2x/fcnn/create_tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..51043ef16b075238c36d8d90d636f6b1d6027fe1 --- /dev/null +++ b/otbtf/examples/tensorflow_v2x/fcnn/create_tfrecords.py @@ -0,0 +1,39 @@ +""" +This example shows how to convert patches-images (like the ones generated from the `PatchesExtraction`) +into TFRecords files. 
+""" +import argparse +from pathlib import Path +from otbtf.examples.tensorflow_v2x.fcnn import helper +from otbtf import DatasetFromPatchesImages + +parser = argparse.ArgumentParser(description="Converts patches-images into TFRecords") +parser.add_argument("--xs", required=True, nargs="+", default=[], help="A list of patches-images for the XS image") +parser.add_argument("--labels", required=True, nargs="+", default=[], + help="A list of patches-images for the labels") +parser.add_argument("--outdir", required=True, help="Output dir for TFRecords files") + + +def create_tfrecords(params): + # Sort patches and labels + patches = sorted(params.xs) + labels = sorted(params.labels) + + # Check patches and labels are correctly sorted + helper.check_files_order(patches, labels) + + # Create output directory + outdir = Path(params.outdir) + if not outdir.exists(): + outdir.mkdir(exist_ok=True) + + # Create dataset from the filename dict + dataset = DatasetFromPatchesImages(filenames_dict={"input_xs_patches": patches, "labels_patches": labels}) + + # Convert the dataset into TFRecords + dataset.to_tfrecords(output_dir=params.outdir, drop_remainder=False) + + +if __name__ == "__main__": + params = parser.parse_args() + create_tfrecords(params) diff --git a/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py b/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..95f2d0172a94ae4b536623fc6cb54e4de9affcfa --- /dev/null +++ b/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py @@ -0,0 +1,135 @@ +""" +Implementation of a small U-Net like model +""" +from otbtf.model import ModelBase +import tensorflow as tf +import logging + +logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') +N_CLASSES = 2 +INPUT_NAME = "input_xs" # name of the input in the `FCNNModel` instance, also name of the input node in the SavedModel +TARGET_NAME = "predictions" # name of the 
output in the `FCNNModel` instance +OUTPUT_SOFTMAX_NAME = "predictions_softmax_tensor" # name (prefix) of the output node in the SavedModel + + +class FCNNModel(ModelBase): + """ + A Simple Fully Convolutional U-Net like model + """ + + def normalize_inputs(self, inputs): + """ + Inherits from `ModelBase` + + The model will use this function internally to normalize its inputs, before applying the `get_outputs()` + function that actually builds the operations graph (convolutions, etc). + This function will hence work at training time and inference time. + + In this example, we assume that we have an input 12 bits multispectral image with values ranging from + [0, 10000], that we process using a simple stretch to roughly match the [0, 1] range. + + :param inputs: dict of inputs + :return: dict of normalized inputs, ready to be used from the `get_outputs()` function of the model + """ + return {INPUT_NAME: tf.cast(inputs[INPUT_NAME], tf.float32) * 0.0001} + + def get_outputs(self, normalized_inputs): + """ + Inherits from `ModelBase` + + This small model produces an output which has the same physical spacing as the input. + The model generates [1 x 1 x N_CLASSES] output pixel for [32 x 32 x <nb channels>] input pixels. 
+ + :param normalized_inputs: dict of normalized inputs + :return: activation values + """ + + norm_inp = normalized_inputs[INPUT_NAME] + + def _conv(inp, depth, name): + return tf.keras.layers.Conv2D(filters=depth, kernel_size=3, strides=2, activation="relu", + padding="same", name=name)(inp) + + def _tconv(inp, depth, name, activation="relu"): + return tf.keras.layers.Conv2DTranspose(filters=depth, kernel_size=3, strides=2, activation=activation, + padding="same", name=name)(inp) + + out_conv1 = _conv(norm_inp, 16, "conv1") + out_conv2 = _conv(out_conv1, 32, "conv2") + out_conv3 = _conv(out_conv2, 64, "conv3") + out_conv4 = _conv(out_conv3, 64, "conv4") + out_tconv1 = _tconv(out_conv4, 64, "tconv1") + out_conv3 + out_tconv2 = _tconv(out_tconv1, 32, "tconv2") + out_conv2 + out_tconv3 = _tconv(out_tconv2, 16, "tconv3") + out_conv1 + out_tconv4 = _tconv(out_tconv3, N_CLASSES, "classifier", None) + + # Generally it is a good thing to name the final layers of the network (i.e. the layers of which outputs are + # returned from the `MyModel.get_outputs()` method). + # Indeed this enables to retrieve them for inference time, using their name. + # In case you forgot to name the last layers, it is still possible to look at the model outputs using the + # `saved_model_cli show --dir /path/to/your/savedmodel --all` command. + # + # Do not confuse **the name of the output layers** (i.e. the "name" property of the tf.keras.layer that is used + # to generate an output tensor) and **the key of the output tensor**, in the dict returned from the + # `MyModel.get_outputs()` method. 
They are two identifiers with a different purpose: + # - the output layer name is used only at inference time, to identify the output tensor from which generate + # the output image, + # - the output tensor key identifies the output tensors, mainly to fit the targets to model outputs during + # training process, but it can also be used to access the tensors as tf/keras objects, for instance to + # display previews images in TensorBoard. + predictions = tf.keras.layers.Softmax(name=OUTPUT_SOFTMAX_NAME)(out_tconv4) + + return {TARGET_NAME: predictions} + + +def dataset_preprocessing_fn(examples): + """ + Preprocessing function for the training dataset. + This function is only used at training time, to put the data in the expected format for the training step. + DO NOT USE THIS FUNCTION TO NORMALIZE THE INPUTS ! (see `otbtf.ModelBase.normalize_inputs` for that). + Note that this function is not called here, but in the code that prepares the datasets. + + :param examples: dict for examples (i.e. inputs and targets stored in a single dict) + :return: preprocessed examples + """ + + def _to_categorical(x): + return tf.one_hot(tf.squeeze(tf.cast(x, tf.int32), axis=-1), depth=N_CLASSES) + + return {INPUT_NAME: examples["input_xs_patches"], + TARGET_NAME: _to_categorical(examples["labels_patches"])} + + +def train(params, ds_train, ds_valid, ds_test): + """ + Create, train, and save the model. + + :param params: contains batch_size, learning_rate, nb_epochs, and model_dir + :param ds_train: training dataset + :param ds_valid: validation dataset + :param ds_test: testing dataset + """ + + strategy = tf.distribute.MirroredStrategy() # For single or multi-GPUs + with strategy.scope(): + # Model instantiation. Note that the normalize_fn is now part of the model + # It is mandatory to instantiate the model inside the strategy scope. 
+ model = FCNNModel(dataset_element_spec=ds_train.element_spec) + + # Compile the model + model.compile(loss=tf.keras.losses.CategoricalCrossentropy(), + optimizer=tf.keras.optimizers.Adam(learning_rate=params.learning_rate), + metrics=[tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]) + + # Summarize the model (in CLI) + model.summary() + + # Train + model.fit(ds_train, epochs=params.nb_epochs, validation_data=ds_valid) + + # Evaluate against test data + if ds_test is not None: + model.evaluate(ds_test, batch_size=params.batch_size) + + # Save trained model as SavedModel + model.save(params.model_dir) diff --git a/otbtf/examples/tensorflow_v2x/fcnn/helper.py b/otbtf/examples/tensorflow_v2x/fcnn/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..aea3a0ac61f1253c8fecc322295d8dc4beb56702 --- /dev/null +++ b/otbtf/examples/tensorflow_v2x/fcnn/helper.py @@ -0,0 +1,37 @@ +""" +A set of helpers for the examples +""" +import argparse + + +def base_parser(): + """ + Create a parser with the base parameters for the training applications + + :return: argparse.ArgumentParser instance + """ + parser = argparse.ArgumentParser(description="Train a FCNN model") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size") + parser.add_argument("--learning_rate", type=float, default=0.0001, help="Learning rate") + parser.add_argument("--nb_epochs", type=int, default=100, help="Number of epochs") + parser.add_argument("--model_dir", required=True, help="Path to save model") + return parser + + +def check_files_order(files1, files2): + """ + Here we check that the two input lists of str are correctly sorted. + Except for the last, splits of files1[i] and files2[i] from the "_" character, must be equal. 
+ + :param files1: list of filenames (str) + :param files2: list of filenames (str) + """ + assert files1 + assert files2 + assert len(files1) == len(files2) + + def get_basename(n): + return "_".join([n.split("_")][:-1]) # NOTE(review): [n.split("_")] is a one-element list, so [:-1] is always [] and this returns "" for every name; the assertion below can therefore never fail. Confirm the intended naming rule before changing it to n.split("_")[:-1]. + + for p, l in zip(files1, files2): + assert get_basename(p) == get_basename(l) diff --git a/otbtf/examples/tensorflow_v2x/fcnn/train_from_patchesimages.py b/otbtf/examples/tensorflow_v2x/fcnn/train_from_patchesimages.py new file mode 100644 index 0000000000000000000000000000000000000000..9299c9e009986aa39d3bf4de8ddd24efff49fb04 --- /dev/null +++ b/otbtf/examples/tensorflow_v2x/fcnn/train_from_patchesimages.py @@ -0,0 +1,62 @@ +""" +This example shows how to use the otbtf python API to train a deep net from patches-images. +""" +from otbtf import DatasetFromPatchesImages +from otbtf.examples.tensorflow_v2x.fcnn import helper +from otbtf.examples.tensorflow_v2x.fcnn import fcnn_model + +parser = helper.base_parser() +parser.add_argument("--train_xs", required=True, nargs="+", default=[], + help="A list of patches-images for the XS image (training dataset)") +parser.add_argument("--train_labels", required=True, nargs="+", default=[], + help="A list of patches-images for the labels (training dataset)") +parser.add_argument("--valid_xs", required=True, nargs="+", default=[], + help="A list of patches-images for the XS image (validation dataset)") +parser.add_argument("--valid_labels", required=True, nargs="+", default=[], + help="A list of patches-images for the labels (validation dataset)") +parser.add_argument("--test_xs", required=False, nargs="+", default=[], + help="A list of patches-images for the XS image (test dataset)") +parser.add_argument("--test_labels", required=False, nargs="+", default=[], + help="A list of patches-images for the labels (test dataset)") + + +def create_dataset(xs_filenames, labels_filenames, batch_size, targets_keys=[fcnn_model.TARGET_NAME]): + """ + Returns a TF dataset generated from an 
`otbtf.DatasetFromPatchesImages` instance + """ + # Sort patches and labels + xs_filenames.sort() + labels_filenames.sort() + + # Check patches and labels are correctly sorted + helper.check_files_order(xs_filenames, labels_filenames) + + # Create dataset from the filename dict + # You can add the `use_streaming` option here, if you want to lower the memory budget. + # However, this can slow down your process since the patches are read on-the-fly from the filesystem. + # Good when one batch computation is slower than one batch gathering! + # You can also use a custom `Iterator` of your own (default is `RandomIterator`). See `otbtf.dataset.Iterator`. + ds = DatasetFromPatchesImages(filenames_dict={"input_xs_patches": xs_filenames, "labels_patches": labels_filenames}) + + # We generate the TF dataset, and we use a preprocessing option to put the labels into one hot encoding (see the + # `fcnn_model.dataset_preprocessing_fn` function). Also, we set the `targets_keys` parameter to ask the dataset to + # deliver samples in the form expected by keras, i.e. a tuple of dicts (inputs_dict, target_dict). 
+ tf_ds = ds.get_tf_dataset(batch_size=batch_size, preprocessing_fn=fcnn_model.dataset_preprocessing_fn, + targets_keys=targets_keys) + + return tf_ds + + +def train(params): + # Create TF datasets + ds_train = create_dataset(params.train_xs, params.train_labels, batch_size=params.batch_size) + ds_valid = create_dataset(params.valid_xs, params.valid_labels, batch_size=params.batch_size) + ds_test = create_dataset(params.test_xs, params.test_labels, + batch_size=params.batch_size) if params.test_xs else None + + # Train the model + fcnn_model.train(params, ds_train, ds_valid, ds_test) + + +if __name__ == "__main__": + train(parser.parse_args()) diff --git a/otbtf/examples/tensorflow_v2x/fcnn/train_from_tfrecords.py b/otbtf/examples/tensorflow_v2x/fcnn/train_from_tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..3fbfe4720e6357aba5485183d3b063ca09cad638 --- /dev/null +++ b/otbtf/examples/tensorflow_v2x/fcnn/train_from_tfrecords.py @@ -0,0 +1,61 @@ +""" +This example shows how to use the otbtf python API to train a deep net from TFRecords. + +We expect that the files are stored in the following way, with m, n, and k denoting respectively +the number of TFRecords files in the training, validation, and test datasets: + +/dataset_dir + /train + 1.records + 2.records + ... + m.records + /valid + 1.records + 2.records + ... + n.records + /test + 1.records + 2.records + ... 
+ k.records + +""" +import os +from otbtf import TFRecords +from otbtf.examples.tensorflow_v2x.fcnn import helper +from otbtf.examples.tensorflow_v2x.fcnn import fcnn_model + +parser = helper.base_parser() +parser.add_argument("--tfrecords_dir", required=True, + help="Directory containing train, valid(, test) folders of TFRecords files") + + +def train(params): + # Patches directories must contain 'train' and 'valid' dirs ('test' is not required) + train_dir = os.path.join(params.tfrecords_dir, "train") + valid_dir = os.path.join(params.tfrecords_dir, "valid") + test_dir = os.path.join(params.tfrecords_dir, "test") + + kwargs = {"batch_size": params.batch_size, + "target_keys": [fcnn_model.TARGET_NAME], + "preprocessing_fn": fcnn_model.dataset_preprocessing_fn} + + # Training dataset. Must be shuffled + assert os.path.isdir(train_dir) + ds_train = TFRecords(train_dir).read(shuffle_buffer_size=1000, **kwargs) + + # Validation dataset + assert os.path.isdir(valid_dir) + ds_valid = TFRecords(valid_dir).read(**kwargs) + + # Test dataset (optional) + ds_test = TFRecords(test_dir).read(**kwargs) if os.path.isdir(test_dir) else None + + # Train the model + fcnn_model.train(params, ds_train, ds_valid, ds_test) + + +if __name__ == "__main__": + train(parser.parse_args()) diff --git a/otbtf/model.py b/otbtf/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9fd3452a9aab26049e33b9d13b24f82d3aacd3 --- /dev/null +++ b/otbtf/model.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +""" Base class for models""" +import abc +import logging +import tensorflow +from otbtf.utils import _is_chief, cropped_tensor_name + + +class ModelBase(abc.ABC): + """ + Base class for all models + """ + + def __init__(self, dataset_element_spec, input_keys=None, inference_cropping=None): + """ + Model initializer, must be called **inside** the strategy.scope(). + + :param dataset_element_spec: the dataset elements specification (shape, dtype, etc). 
Can be retrieved from the + dataset instance simply with `ds.element_spec` + :param input_keys: Optional. the keys of the inputs used in the model. If not specified, all inputs from the + dataset will be considered. + :param inference_cropping: list of number of pixels to be removed on each side of the output during inference. + This list creates some additional outputs in the model, not used during training, + only during inference. Default [16, 32, 64, 96, 128] + """ + # Retrieve dataset inputs shapes + dataset_input_element_spec = dataset_element_spec[0] + logging.info("Dataset input element spec: %s", dataset_input_element_spec) + + if input_keys: + self.dataset_input_keys = input_keys + logging.info("Using input keys: %s", self.dataset_input_keys) + else: + self.dataset_input_keys = list(dataset_input_element_spec) + logging.info("Found dataset input keys: %s", self.dataset_input_keys) + + self.inputs_shapes = {key: dataset_input_element_spec[key].shape[1:] for key in self.dataset_input_keys} + logging.info("Inputs shapes: %s", self.inputs_shapes) + + # Setup cropping, normalization function + self.inference_cropping = [16, 32, 64, 96, 128] if not inference_cropping else inference_cropping + logging.info("Inference cropping values: %s", self.inference_cropping) + + # Create model + self.model = self.create_network() + + def __getattr__(self, name): + """This method is called when the default attribute access fails. We choose to try to access the attribute of + self.model. Thus, any method of keras.Model() can be used transparently, e.g. 
model.summary() or model.fit()""" + return getattr(self.model, name) + + def get_inputs(self): + """ + This method returns the dict of keras.Input + """ + # Create Keras inputs + model_inputs = {} + for key in self.dataset_input_keys: + new_shape = list(self.inputs_shapes[key]) + logging.info("Original shape for input %s: %s", key, new_shape) + # Here we modify the x and y dims of >2D tensors to enable any image size at input + if len(new_shape) > 2: + new_shape[0] = None + new_shape[1] = None + placeholder = tensorflow.keras.Input(shape=new_shape, name=key) + logging.info("New shape for input %s: %s", key, new_shape) + model_inputs.update({key: placeholder}) + return model_inputs + + @abc.abstractmethod + def get_outputs(self, normalized_inputs): + """ + Implementation of the model, from the normalized inputs. + + :param normalized_inputs: normalized inputs, as generated from `self.normalize_inputs()` + :return: dict of model outputs + """ + raise NotImplementedError("This method has to be implemented. Here you code the model :)") + + def normalize_inputs(self, inputs): + """ + Normalize the model inputs. + Takes the dict of inputs and returns a dict of normalized inputs. + + :param inputs: model inputs + :return: a dict of normalized model inputs + """ + logging.warning("normalize_input() undefined. No normalization of the model inputs will be performed. " + "You can implement the function in your model class if you want.") + return inputs + + def postprocess_outputs(self, outputs, inputs=None, normalized_inputs=None): + """ + Post-process the model outputs. + Takes the dicts of inputs and outputs, and returns a dict of post-processed outputs. 
+ The default implementation provides a set of cropped output tensors + + :param outputs: dict of model outputs + :param inputs: dict of model inputs (optional) + :param normalized_inputs: dict of normalized model inputs (optional) + :return: a dict of post-processed model outputs + """ + + # Add extra outputs for inference + extra_outputs = {} + for out_key, out_tensor in outputs.items(): + for crop in self.inference_cropping: + extra_output_key = cropped_tensor_name(out_key, crop) + extra_output_name = cropped_tensor_name(out_tensor._keras_history.layer.name, crop) + logging.info("Adding extra output for tensor %s with crop %s (%s)", out_key, crop, extra_output_name) + cropped = out_tensor[:, crop:-crop, crop:-crop, :] + identity = tensorflow.keras.layers.Activation('linear', name=extra_output_name) + extra_outputs[extra_output_key] = identity(cropped) + + return extra_outputs + + def create_network(self): + """ + This method returns the Keras model. This needs to be called **inside** the strategy.scope(). + Can be reimplemented depending on the needs. + + :return: the keras model + """ + + # Get the model inputs + inputs = self.get_inputs() + logging.info("Model inputs: %s", inputs) + + # Normalize the inputs + normalized_inputs = self.normalize_inputs(inputs=inputs) + logging.info("Normalized model inputs: %s", normalized_inputs) + + # Build the model + outputs = self.get_outputs(normalized_inputs=normalized_inputs) + logging.info("Model outputs: %s", outputs) + + # Post-processing for inference + postprocessed_outputs = self.postprocess_outputs(outputs=outputs, inputs=inputs, + normalized_inputs=normalized_inputs) + outputs.update(postprocessed_outputs) + + # Return the keras model + return tensorflow.keras.Model(inputs=inputs, outputs=outputs, name=self.__class__.__name__) + + def summary(self, strategy=None): + """ + Wraps the summary printing of the model. 
When multiworker strategy, only prints if the worker is chief + """ + if not strategy or _is_chief(strategy): + self.model.summary(line_length=150) + + def plot(self, output_path, strategy=None): + """ + Enables to save a figure representing the architecture of the network. + Needs pydot and graphviz to work (`pip install pydot` and https://graphviz.gitlab.io/download/) + """ + assert self.model, "Plot() only works if create_network() has been called beforehand" + + # When multiworker strategy, only plot if the worker is chief + if not strategy or _is_chief(strategy): + # Build a simplified model, without normalization nor extra outputs. + # This model is only used for plotting the architecture thanks to `keras.utils.plot_model` + inputs = self.get_inputs() # inputs without normalization + outputs = self.get_outputs(inputs) # raw model outputs + model_simplified = tensorflow.keras.Model(inputs=inputs, outputs=outputs, + name=self.__class__.__name__ + '_simplified') + tensorflow.keras.utils.plot_model(model_simplified, output_path) + diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py index b2aae0b2fdc7f1a0341840b447c121694b8c8367..123bdea554d07eb537fcd7455747db00bd7ce110 100644 --- a/otbtf/tfrecords.py +++ b/otbtf/tfrecords.py @@ -41,8 +41,8 @@ class TFRecords: self.dirpath = path os.makedirs(self.dirpath, exist_ok=True) self.output_types_file = os.path.join(self.dirpath, "output_types.json") - self.output_shape_file = os.path.join(self.dirpath, "output_shape.json") - self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None + self.output_shapes_file = os.path.join(self.dirpath, "output_shapes.json") + self.output_shapes = self.load(self.output_shapes_file) if os.path.exists(self.output_shapes_file) else None self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None @staticmethod @@ -70,8 +70,8 @@ class TFRecords: if not drop_remainder and dataset.size % 
n_samples_per_shard > 0: nb_shards += 1 - output_shapes = {key: (None,) + output_shape for key, output_shape in dataset.output_shapes.items()} - self.save(output_shapes, self.output_shape_file) + output_shapes = {key: output_shape for key, output_shape in dataset.output_shapes.items()} + self.save(output_shapes, self.output_shapes_file) output_types = {key: output_type.name for key, output_type in dataset.output_types.items()} self.save(output_types, self.output_types_file) @@ -101,7 +101,6 @@ class TFRecords: :param data: Data to save json format :param filepath: Output file name """ - with open(filepath, 'w') as file: json.dump(data, file, indent=4) @@ -114,34 +113,38 @@ class TFRecords: with open(filepath, 'r') as file: return json.load(file) - @staticmethod - def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): + def parse_tfrecord(self, example, target_keys, preprocessing_fn=None, **kwargs): """ Parse example object to sample dict. :param example: Example object to parse - :param features_types: List of types for each feature :param target_keys: list of keys of the targets - :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns - a tuple (input_preprocessed, target_preprocessed) + :param preprocessing_fn: Optional. 
A preprocessing function that process the input example :param kwargs: some keywords arguments for preprocessing_fn """ - read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} + read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in self.output_types} example_parsed = tf.io.parse_single_example(example, read_features) - for key in read_features.keys(): - example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key]) + # Tensor with right data type + for key, out_type in self.output_types.items(): + example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=out_type) + + # Ensure shape + for key, shape in self.output_shapes.items(): + example_parsed[key] = tf.ensure_shape(example_parsed[key], shape) - # Differentiating inputs and outputs - input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} - target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + # Preprocessing + example_parsed_prep = preprocessing_fn(example_parsed, **kwargs) if preprocessing_fn else example_parsed - if preprocessing_fn: - input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs) + # Differentiating inputs and targets + input_parsed = {key: value for (key, value) in example_parsed_prep.items() if key not in target_keys} + target_parsed = {key: value for (key, value) in example_parsed_prep.items() if key in target_keys} return input_parsed, target_parsed def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None, - preprocessing_fn=None, **kwargs): + preprocessing_fn=None, shard_policy=tf.data.experimental.AutoShardPolicy.AUTO, + prefetch_buffer_size=tf.data.experimental.AUTOTUNE, + num_parallel_calls=tf.data.experimental.AUTOTUNE, **kwargs): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. 
:param batch_size: Size of tensorflow batch @@ -153,18 +156,28 @@ class TFRecords: False is advisable when evaluating metrics so that all samples are used :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size elements are shuffled using uniform random. - :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns - a tuple (input_preprocessed, target_preprocessed) + :param preprocessing_fn: Optional. A preprocessing function that takes input examples as args and returns the + preprocessed input examples. Typically, examples are composed of model inputs and + targets. Model inputs and model targets must be computed accordingly to (1) what the + model outputs and (2) what training loss needs. For instance, for a classification + problem, the model will likely output the softmax, or activation neurons, for each + class, and the cross entropy loss requires labels in one hot encoding. In this case, + the preprocessing_fn has to transform the labels values (integer ranging from + [0, n_classes]) in one hot encoding (vector of 0 and 1 of length n_classes). The + preprocessing_fn should not implement such things as radiometric transformations from + input to input_preprocessed, because those are performed inside the model itself + (see `otbtf.ModelBase.normalize_inputs()`). 
+ :param shard_policy: sharding policy for the TFRecordDataset options + :param prefetch_buffer_size: buffer size for the prefetch operation + :param num_parallel_calls: number of parallel calls for the parsing + preprocessing step :param kwargs: some keywords arguments for preprocessing_fn """ options = tf.data.Options() if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed - options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, - preprocessing_fn=preprocessing_fn, **kwargs) + options.experimental_distribute.auto_shard_policy = shard_policy # for multiworker + parse = partial(self.parse_tfrecord, target_keys=target_keys, preprocessing_fn=preprocessing_fn, **kwargs) - # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ? 
tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") @@ -179,10 +192,10 @@ class TFRecords: logging.info('Reducing number of records to : %s', nb_matching_files) dataset = tf.data.TFRecordDataset(matching_files) # , num_parallel_reads=2) # interleaves reads from xxx files dataset = dataset.with_options(options) # uses data as soon as it streams in, rather than in its original order - dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE) + dataset = dataset.map(parse, num_parallel_calls=num_parallel_calls) if shuffle_buffer_size: dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) - dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + dataset = dataset.prefetch(buffer_size=prefetch_buffer_size) return dataset diff --git a/otbtf/utils.py b/otbtf/utils.py index 069638a551e0226c9ff3ad6cc1537359922b29c1..7aa777af75e3b92bdacde6d28c3b1b1e6072de4b 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -63,3 +63,37 @@ def read_as_np_arr(gdal_ds, as_patches=True, dtype=None): buffer = buffer.astype(dtype) return buffer + + +def _is_chief(strategy): + """ + Tell if the current worker is the chief. + + :param strategy: strategy + :return: True if the current worker is the chief, False else + """ + # Note: there are two possible `TF_CONFIG` configuration. + # 1) In addition to `worker` tasks, a `chief` task type is use; + # in this case, this function should be modified to + # `return task_type == 'chief'`. + # 2) Only `worker` task type is used; in this case, worker 0 is + # regarded as the chief. The implementation demonstrated here + # is for this case. + # For the purpose of this Colab section, the `task_type is None` case + # is added because it is effectively run with only a single worker. 
+ + if strategy.cluster_resolver: # this means MultiWorkerMirroredStrategy + task_type, task_id = strategy.cluster_resolver.task_type, strategy.cluster_resolver.task_id + return (task_type == 'chief') or (task_type == 'worker' and task_id == 0) or task_type is None + # strategy with only one worker + return True + + +def cropped_tensor_name(tensor_name, crop): + """ + A name for the cropped tensor + :param tensor_name: tensor name + :param crop: crop value + :return: name + """ + return "{}_crop{}".format(tensor_name, crop) diff --git a/test/api_unittest.py b/test/api_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..295824895633ad3ed8c24346a92dd9622f1e2d1b --- /dev/null +++ b/test/api_unittest.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import pytest +import unittest +from test_utils import resolve_paths, files_exist, run_command_and_compare +from otbtf.examples.tensorflow_v2x.fcnn.fcnn_model import INPUT_NAME, OUTPUT_SOFTMAX_NAME +from otbtf.examples.tensorflow_v2x.fcnn import train_from_patchesimages +from otbtf.examples.tensorflow_v2x.fcnn import train_from_tfrecords +from otbtf.examples.tensorflow_v2x.fcnn import create_tfrecords +from otbtf.model import cropped_tensor_name + +INFERENCE_MAE_TOL = 10.0 # Dummy value: we don't really care about the MAE value but rather the image size etc + + +class APITest(unittest.TestCase): + + @pytest.mark.order(1) + def test_train_from_patchesimages(self): + params = train_from_patchesimages.parser.parse_args(['--model_dir', resolve_paths('$TMPDIR/model_from_pimg'), + '--nb_epochs', '1', + '--train_xs', + resolve_paths('$DATADIR/amsterdam_patches_A.tif'), + '--train_labels', + resolve_paths('$DATADIR/amsterdam_labels_A.tif'), + '--valid_xs', + resolve_paths('$DATADIR/amsterdam_patches_B.tif'), + '--valid_labels', + resolve_paths('$DATADIR/amsterdam_labels_B.tif')]) + train_from_patchesimages.train(params=params) + 
self.assertTrue(files_exist(['$TMPDIR/model_from_pimg/keras_metadata.pb', + '$TMPDIR/model_from_pimg/saved_model.pb', + '$TMPDIR/model_from_pimg/variables/variables.data-00000-of-00001', + '$TMPDIR/model_from_pimg/variables/variables.index'])) + + @pytest.mark.order(2) + def test_model_inference1(self): + self.assertTrue( + run_command_and_compare( + command= + "otbcli_TensorflowModelServe " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.rfieldx 64 " + "-source1.rfieldy 64 " + f"-source1.placeholder {INPUT_NAME} " + "-model.dir $TMPDIR/model_from_pimg " + "-model.fullyconv on " + f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 16)} " + "-output.efieldx 32 " + "-output.efieldy 32 " + "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8", + to_compare_dict={"$DATADIR/classif_model4_softmax.tif": "$TMPDIR/classif_model4_softmax.tif"}, + tol=INFERENCE_MAE_TOL)) + self.assertTrue( + run_command_and_compare( + command= + "otbcli_TensorflowModelServe " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.rfieldx 128 " + "-source1.rfieldy 128 " + f"-source1.placeholder {INPUT_NAME} " + "-model.dir $TMPDIR/model_from_pimg " + "-model.fullyconv on " + f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 32)} " + "-output.efieldx 64 " + "-output.efieldy 64 " + "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8", + to_compare_dict={"$DATADIR/classif_model4_softmax.tif": "$TMPDIR/classif_model4_softmax.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(3) + def test_create_tfrecords(self): + params = create_tfrecords.parser.parse_args(['--xs', resolve_paths('$DATADIR/amsterdam_patches_A.tif'), + '--labels', resolve_paths('$DATADIR/amsterdam_labels_A.tif'), + '--outdir', resolve_paths('$TMPDIR/train')]) + create_tfrecords.create_tfrecords(params=params) + self.assertTrue(files_exist(['$TMPDIR/train/output_shapes.json', + '$TMPDIR/train/output_types.json', + '$TMPDIR/train/0.records'])) + params = 
create_tfrecords.parser.parse_args(['--xs', resolve_paths('$DATADIR/amsterdam_patches_B.tif'), + '--labels', resolve_paths('$DATADIR/amsterdam_labels_B.tif'), + '--outdir', resolve_paths('$TMPDIR/valid')]) + create_tfrecords.create_tfrecords(params=params) + self.assertTrue(files_exist(['$TMPDIR/valid/output_shapes.json', + '$TMPDIR/valid/output_types.json', + '$TMPDIR/valid/0.records'])) + + @pytest.mark.order(4) + def test_train_from_tfrecords(self): + params = train_from_tfrecords.parser.parse_args(['--model_dir', resolve_paths('$TMPDIR/model_from_tfrecs'), + '--nb_epochs', '1', + '--tfrecords_dir', resolve_paths('$TMPDIR')]) + train_from_tfrecords.train(params=params) + self.assertTrue(files_exist(['$TMPDIR/model_from_tfrecs/keras_metadata.pb', + '$TMPDIR/model_from_tfrecs/saved_model.pb', + '$TMPDIR/model_from_tfrecs/variables/variables.data-00000-of-00001', + '$TMPDIR/model_from_tfrecs/variables/variables.index'])) + + @pytest.mark.order(5) + def test_model_inference2(self): # serves the model saved by test_train_from_tfrecords (order 4), not the patches-images one + self.assertTrue( + run_command_and_compare( + command= + "otbcli_TensorflowModelServe " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.rfieldx 64 " + "-source1.rfieldy 64 " + f"-source1.placeholder {INPUT_NAME} " + "-model.dir $TMPDIR/model_from_tfrecs " + "-model.fullyconv on " + f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 16)} " + "-output.efieldx 32 " + "-output.efieldy 32 " + "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8", + to_compare_dict={"$DATADIR/classif_model4_softmax.tif": "$TMPDIR/classif_model4_softmax.tif"}, + tol=INFERENCE_MAE_TOL)) + + self.assertTrue( + run_command_and_compare( + command= + "otbcli_TensorflowModelServe " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.rfieldx 128 " + "-source1.rfieldy 128 " + f"-source1.placeholder {INPUT_NAME} " + "-model.dir $TMPDIR/model_from_tfrecs " + "-model.fullyconv on " + f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 32)} " + "-output.efieldx 64 " + 
"-output.efieldy 64 " + "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8", + to_compare_dict={"$DATADIR/classif_model4_softmax.tif": "$TMPDIR/classif_model4_softmax.tif"}, + tol=INFERENCE_MAE_TOL)) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/data/classif_model4_softmax.tif b/test/data/classif_model4_softmax.tif new file mode 100644 index 0000000000000000000000000000000000000000..eb5ff03c8739d91b8866de470f6004b216856106 --- /dev/null +++ b/test/data/classif_model4_softmax.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a2ef82781b7e42c82be069db085e7dcc6a3c5b9db84c1277153fe730fd52741 +size 9504 diff --git a/test/test_utils.py b/test/test_utils.py index c07301e91bb2eab4a808d53beaa46396c3ea82fc..4554e28e3093e82d17faddf02ca0e14ad5536be1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,6 @@ import otbApplication import os +from pathlib import Path def get_nb_of_channels(raster): @@ -42,15 +43,59 @@ def compare(raster1, raster2, tol=0.01): return True -def resolve_paths(filename, var_list): +def resolve_paths(path): """ - Retrieve environment variables in paths - :param filename: file name - :params var_list: variable list - :return filename with retrieved environment variables + Resolve a path with the environment variables """ - new_filename = filename - for var in var_list: - new_filename = new_filename.replace("${}".format(var), os.environ[var]) - print("Resolve filename...\n\tfilename: {}, \n\tnew filename: {}".format(filename, new_filename)) - return new_filename + return os.path.expandvars(path) + + +def files_exist(file_list): + """ + Check is all files exist + """ + print("Checking if files exist...") + for file in file_list: + print("\t{}".format(file)) + path = Path(resolve_paths(file)) + if not path.is_file(): + print("File {} does not exist!".format(file)) + return False + print("\tOk") + return True + + +def run_command(command): + """ + Run a command + :param 
command: the command to run + """ + full_command = resolve_paths(command) + print("Running command: \n\t {}".format(full_command)) + os.system(full_command) + + +def run_command_and_test_exist(command, file_list): + """ + :param command: the command to run (str) + :param file_list: list of files to check + :return True or False + """ + run_command(command) + return files_exist(file_list) + + +def run_command_and_compare(command, to_compare_dict, tol=0.01): + """ + :param command: the command to run (str) + :param to_compare_dict: a dict of {baseline1: output1, ..., baselineN: outputN} + :param tol: tolerance (float) + :return True or False + """ + + run_command(command) + for baseline, output in to_compare_dict.items(): + if not compare(resolve_paths(baseline), resolve_paths(output), tol): + print("Baseline {} and output {} differ.".format(baseline, output)) + return False + return True diff --git a/test/tutorial_unittest.py b/test/tutorial_unittest.py index 7934862f348326de13a11eef4c81a4fe6512cec6..af2b181c8442f8c7033b0e7ea94f39d03d262efc 100644 --- a/test/tutorial_unittest.py +++ b/test/tutorial_unittest.py @@ -2,64 +2,11 @@ # -*- coding: utf-8 -*- import pytest import unittest -import os -from pathlib import Path -import test_utils +from test_utils import run_command, run_command_and_test_exist, run_command_and_compare INFERENCE_MAE_TOL = 10.0 # Dummy value: we don't really care of the mae value but rather the image size etc -def resolve_paths(path): - """ - Resolve a path with the environment variables - """ - return test_utils.resolve_paths(path, var_list=["TMPDIR", "DATADIR"]) - - -def run_command(command): - """ - Run a command - :param command: the command to run - """ - full_command = resolve_paths(command) - print("Running command: \n\t {}".format(full_command)) - os.system(full_command) - - -def run_command_and_test_exist(command, file_list): - """ - :param command: the command to run (str) - :param file_list: list of files to check - :return True or 
False - """ - run_command(command) - print("Checking if files exist...") - for file in file_list: - print("\t{}".format(file)) - path = Path(resolve_paths(file)) - if not path.is_file(): - print("File {} does not exist!".format(file)) - return False - print("\tOk") - return True - - -def run_command_and_compare(command, to_compare_dict, tol=0.01): - """ - :param command: the command to run (str) - :param to_compare_dict: a dict of {baseline1: output1, ..., baselineN: outputN} - :param tol: tolerance (float) - :return True or False - """ - - run_command(command) - for baseline, output in to_compare_dict.items(): - if not test_utils.compare(resolve_paths(baseline), resolve_paths(output), tol): - print("Baseline {} and output {} differ.".format(baseline, output)) - return False - return True - - class TutorialTest(unittest.TestCase): @pytest.mark.order(1) diff --git a/tools/docker/README.md b/tools/docker/README.md index 4246b74d33b88ee1daa38c2f3de7ab44911f0917..dc718e683254d15637dbd88190cbbe546a9ee5b1 100644 --- a/tools/docker/README.md +++ b/tools/docker/README.md @@ -111,7 +111,7 @@ If you see OOM errors during SuperBuild you should decrease CPU_RATIO (e.g. 
0.75 ## Container examples ```bash # Pull GPU image and create a new container with your home directory as volume (requires apt package nvidia-docker2 and CUDA>=11.0) -docker create --gpus=all --volume $HOME:/home/otbuser/volume -it --name otbtf-gpu mdl4eo/otbtf2.4:gpu +docker create --gpus=all --volume $HOME:/home/otbuser/volume -it --name otbtf-gpu mdl4eo/otbtf:3.3.0-gpu # Run interactive docker start -i otbtf-gpu @@ -123,7 +123,7 @@ docker exec otbtf-gpu python -c 'import tensorflow as tf; print(tf.test.is_gpu_a ### Rebuild OTB with more modules ```bash -docker create --gpus=all -it --name otbtf-gpu-dev mdl4eo/otbtf2.4:gpu-dev +docker create --gpus=all -it --name otbtf-gpu-dev mdl4eo/otbtf:3.3.0-gpu-dev docker start -i otbtf-gpu-dev ``` ```bash diff --git a/tools/docker/build-deps-cli.txt b/tools/docker/build-deps-cli.txt index 5d699cb19db6cd4845acaa909f50e148a172e318..ffd72911c2ef56105a24e0c83f764aa9af4d7540 100644 --- a/tools/docker/build-deps-cli.txt +++ b/tools/docker/build-deps-cli.txt @@ -25,8 +25,6 @@ wget zip bison -gdal-bin -python3-gdal libboost-date-time-dev libboost-filesystem-dev libboost-graph-dev @@ -36,8 +34,6 @@ libboost-thread-dev libcurl4-gnutls-dev libexpat1-dev libfftw3-dev -libgdal-dev -libgeotiff-dev libgsl-dev libinsighttoolkit4-dev libkml-dev @@ -45,9 +41,6 @@ libmuparser-dev libmuparserx-dev libopencv-core-dev libopencv-ml-dev -libopenthreads-dev -libossim-dev -libpng-dev libsvm-dev libtinyxml-dev zlib1g-dev diff --git a/tools/docker/build-flags-otb.txt b/tools/docker/build-flags-otb.txt index 2c3e0feac4e480cd9f8a0c9969b70c301761e8da..def7bd2b8847834ad434ff31c7ff4eb5c6aaaeb7 100644 --- a/tools/docker/build-flags-otb.txt +++ b/tools/docker/build-flags-otb.txt @@ -3,9 +3,9 @@ -DUSE_SYSTEM_EXPAT=ON -DUSE_SYSTEM_FFTW=ON -DUSE_SYSTEM_FREETYPE=ON --DUSE_SYSTEM_GDAL=ON +-DUSE_SYSTEM_GDAL=OFF -DUSE_SYSTEM_GEOS=ON --DUSE_SYSTEM_GEOTIFF=ON +-DUSE_SYSTEM_GEOTIFF=OFF -DUSE_SYSTEM_GLEW=ON -DUSE_SYSTEM_GLFW=ON -DUSE_SYSTEM_GLUT=ON @@ -16,8 +16,6 @@ 
-DUSE_SYSTEM_MUPARSER=ON -DUSE_SYSTEM_MUPARSERX=ON -DUSE_SYSTEM_OPENCV=ON --DUSE_SYSTEM_OPENTHREADS=ON --DUSE_SYSTEM_OSSIM=ON -DUSE_SYSTEM_PNG=ON -DUSE_SYSTEM_QT5=ON -DUSE_SYSTEM_QWT=ON