diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..6996832664a568840544ffb055e381a2829e2cd8 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,37 @@ +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..eee13ac82659361f18f3595742fd0355de307e39 --- /dev/null +++ b/.gitignore @@ -0,0 +1,423 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) +*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml + +# Python +*.egg-info/ +/build + +# MoGe +/data* +/download +/extract +/debug +/workspace +/mlruns +/infer_output +/video_output +/eval_output +/.blobcache +/test_images +/test_videos +/vis +/videos +/blobmnt +/eval_dump +/pretrained +/.gradio +/tmp \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..b6caf5dc5366eb9322d9a9e7f4f7efb46da5a3d1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,32 @@ +## 2024-11-28 +### Added +- Supported user-provided camera FOV. See [scripts/infer.py](scripts/infer.py) --fov_x. + - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24). +- Added inference scripts for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py). + - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19). + +### Fixed +- Suppressed unnecessary numpy runtime warnings. +- Specified recommended versions of requirements. + - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21). + +### Changed +- Moved `app.py` and `infer.py` to [scripts/](scripts/) +- Improved edge removal. + +## 2025-03-18 +### Added +- Training and evaluation code. See [docs/train.md](docs/train.md) and [docs/eval.md](docs/eval.md). +- Supported installation via pip. Thanks to @fabiencastan and @jgoueslard + for commits in the [#47](https://github.com/microsoft/MoGe/pull/47) +- Supported command-line usage when installed. + +### Changed +- Moved `scripts/` into `moge/` for package installation and command-line usage. +- Renamed `moge.model.moge_model` to `moge.model.v1` for version management. + Now you can import the model class through `from moge.model.v1 import MoGeModel` or `from moge.model import import_model_class_by_version; MoGeModel = import_model_class_by_version('v1')`. +- Exposed `num_tokens` parameter in MoGe model. + +## 2025-06-10 +### Added +- Released MoGe-2. \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..f9ba8cf65f3e3104dd061c178066ec8247811f33 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3458b5ccd398afed340e17a4d0615c9a8666bb5d --- /dev/null +++ b/LICENSE @@ -0,0 +1,224 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cb5ad9cf7a3be8dda746f35e6bc2171859b81ad0 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: MoGe 2 +emoji: 🚀 +colorFrom: indigo +colorTo: purple +sdk: gradio +sdk_version: 5.33.0 +app_file: app.py +pinned: false +license: mit +short_description: Monocular metric-scale geometry estimation +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..b3c89efc852e22f71eabf5dfbc6ac62493425eb6 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). + + diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 0000000000000000000000000000000000000000..291d4d43733f4c15a81ff598ec1c99fd6c18f64c --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,25 @@ +# TODO: The maintainer of this repo has not yet edited this file + +**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? + +- **No CSS support:** Fill out this template with information about how to file issues and get help. +- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. +- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. + +*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* + +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. + +For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE +FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER +CHANNEL. WHERE WILL YOU HELP PEOPLE?**. + +## Microsoft Support Policy + +Support for this **PROJECT or PRODUCT** is limited to the resources listed above. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5eb6fd671e12efbc19c12ab3997bdcf205f7e8 --- /dev/null +++ b/app.py @@ -0,0 +1,298 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +import sys +from pathlib import Path +import time +import uuid +import tempfile +import itertools +from typing import * +import atexit +from concurrent.futures import ThreadPoolExecutor +import shutil + +import click + + +@click.command(help='Web demo') +@click.option('--share', is_flag=True, help='Whether to run the app in shared mode.') +@click.option('--pretrained', 'pretrained_model_name_or_path', default=None, help='The name or path of the pre-trained model.') +@click.option('--version', 'model_version', default='v2', help='The version of the model.') +def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool = True): + print("Import modules...") + # Lazy import + import cv2 + import torch + import numpy as np + import trimesh + import trimesh.visual + from PIL import Image + import gradio as gr + try: + import spaces # This is for deployment at huggingface.co/spaces + HUGGINFACE_SPACES_INSTALLED = True + except ImportError: + HUGGINFACE_SPACES_INSTALLED = False + + import utils3d + from moge.utils.io import write_normal + from moge.utils.vis import colorize_depth, colorize_normal + from moge.model import import_model_class_by_version + from moge.utils.geometry_numpy import depth_occlusion_edge_numpy + from moge.utils.tools import timeit + + print("Load model...") + if pretrained_model_name_or_path is None: + DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = { + "v1": "Ruicheng/moge-vitl", + "v2": "Ruicheng/moge-2-vitl-normal", + } + pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version] + model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval() + if use_fp16: + model.half() + thread_pool_executor = ThreadPoolExecutor(max_workers=1) + + def delete_later(path: Union[str, os.PathLike], delay: int = 300): + def _delete(): + try: + os.remove(path) + except FileNotFoundError: + pass + def _wait_and_delete(): + time.sleep(delay) + _delete(path) + thread_pool_executor.submit(_wait_and_delete) + atexit.register(_delete) + + # Inference on GPU. + @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else lambda x: x) + def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]: + image_tensor = torch.tensor(image, dtype=torch.float32 if not use_fp16 else torch.float16, device=torch.device('cuda')).permute(2, 0, 1) / 255 + output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=use_fp16) + output = {k: v.cpu().numpy() for k, v in output.items()} + return output + + # Full inference pipeline + def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None): + larger_size = max(image.shape[:2]) + if larger_size > max_size: + scale = max_size / larger_size + image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + + height, width = image.shape[:2] + + resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 18}.get(resolution_level, 9) + output = run_with_gpu(image, resolution_level_int, apply_mask) + + points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None) + + if remove_edge: + mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=0.04) + else: + mask_cleaned = mask + + results = { + **output, + 'mask_cleaned': mask_cleaned, + 'image': image + } + + # depth & normal visualization + depth_vis = colorize_depth(depth) + if normal is not None: + normal_vis = colorize_normal(normal) + else: + normal_vis = gr.update(label="Normal map (not avalable for this model)") + + # mesh & pointcloud + if normal is None: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask_cleaned, + tri=True + ) + vertex_normals = None + else: + faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + normal, + mask=mask_cleaned, + tri=True + ) + vertices = vertices * np.array([1, -1, -1], dtype=np.float32) + vertex_uvs = vertex_uvs * np.array([1, -1], dtype=np.float32) + np.array([0, 1], dtype=np.float32) + if vertex_normals is not None: + vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32) + + tempdir = Path(tempfile.gettempdir(), 'moge') + tempdir.mkdir(exist_ok=True) + output_path = Path(tempdir, request.session_hash) + shutil.rmtree(output_path, ignore_errors=True) + output_path.mkdir(exist_ok=True, parents=True) + trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=vertex_normals, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(image), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + process=False + ).export(output_path / 'mesh.glb') + pointcloud = trimesh.PointCloud( + vertices=vertices, + colors=vertex_colors, + ) + pointcloud.vertex_normals = vertex_normals + pointcloud.export(output_path / 'pointcloud.ply', vertex_normal=True) + trimesh.PointCloud( + vertices=vertices, + colors=vertex_colors, + ).export(output_path / 'pointcloud.glb', include_normals=True) + cv2.imwrite(str(output_path /'mask.png'), mask.astype(np.uint8) * 255) + cv2.imwrite(str(output_path / 'depth.exr'), depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(output_path / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + if normal is not None: + cv2.imwrite(str(output_path / 'normal.exr'), cv2.cvtColor(normal.astype(np.float32) * np.array([1, -1, -1], dtype=np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + + files = ['mesh.glb', 'pointcloud.ply', 'depth.exr', 'points.exr', 'mask.png'] + if normal is not None: + files.append('normal.exr') + + for f in files: + delete_later(output_path / f) + + # FOV + intrinsics = results['intrinsics'] + fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + fov_x, fov_y = np.rad2deg([fov_x, fov_y]) + + # messages + viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.' + if resolution_level != 'Ultra': + depth_message = f'**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.' + else: + depth_message = "" + + return ( + results, + depth_vis, + normal_vis, + output_path / 'pointcloud.glb', + [(output_path / f).as_posix() for f in files if (output_path / f).exists()], + f'- **Horizontal FOV: {fov_x:.1f}°**. \n - **Vertical FOV: {fov_y:.1f}°**', + viewer_message, + depth_message + ) + + def reset_measure(results: Dict[str, np.ndarray]): + return [results['image'], [], ""] + + + def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int]], event: gr.SelectData): + point2d = event.index[0], event.index[1] + measure_points.append(point2d) + + image = results['image'].copy() + for p in measure_points: + image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2) + + depth_text = "" + for i, p in enumerate(measure_points): + d = results['depth'][p[1], p[0]] + depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" + + if len(measure_points) == 2: + point1, point2 = measure_points + image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2) + distance = np.linalg.norm(results['points'][point1[1], point1[0]] - results['points'][point2[1], point2[0]]) + measure_points = [] + + distance_text = f"- **Distance: {distance:.2f}m**" + + text = depth_text + distance_text + return [image, measure_points, text] + else: + return [image, measure_points, depth_text] + + print("Create Gradio app...") + with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown( +f''' +
+

Turn a 2D image into 3D with MoGe badge-github-stars

+
+''') + results = gr.State(value=None) + measure_points = gr.State(value=[]) + + with gr.Row(): + with gr.Column(): + input_image = gr.Image(type="numpy", image_mode="RGB", label="Input Image") + with gr.Accordion(label="Settings", open=False): + max_size_input = gr.Number(value=800, label="Maximum Image Size", precision=0, minimum=256, maximum=2048) + resolution_level = gr.Dropdown(['Low', 'Medium', 'High', 'Ultra'], label="Inference Resolution Level", value='High') + apply_mask = gr.Checkbox(value=True, label="Apply mask") + remove_edges = gr.Checkbox(value=True, label="Remove edges") + submit_btn = gr.Button("Submit", variant='primary') + + with gr.Column(): + with gr.Tabs(): + with gr.Tab("3D View"): + viewer_message = gr.Markdown("") + model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1.0, 1.0, 1.0, 1.0], height="60vh") + fov = gr.Markdown() + with gr.Tab("Depth"): + depth_message = gr.Markdown("") + depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format='png', interactive=False) + with gr.Tab("Normal", interactive=hasattr(model, 'normal_head')): + normal_map = gr.Image(type="numpy", label="Normal Map", format='png', interactive=False) + with gr.Tab("Measure", interactive=hasattr(model, 'scale_head')): + gr.Markdown("### Click on the image to measure the distance between two points. \n" + "**Note:** Metric scale is most reliable for typical indoor or street scenes, and may degrade for contents unfamiliar to the model (e.g., stylized or close-up images).") + measure_image = gr.Image(type="numpy", show_label=False, format='webp', interactive=False, sources=[]) + measure_text = gr.Markdown("") + with gr.Tab("Download"): + files = gr.File(type='filepath', label="Output Files") + + if Path('example_images').exists(): + example_image_paths = sorted(list(itertools.chain(*[Path('example_images').glob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']]))) + examples = gr.Examples( + examples = example_image_paths, + inputs=input_image, + label="Examples" + ) + + submit_btn.click( + fn=lambda: [None, None, None, None, None, "", "", ""], + outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message] + ).then( + fn=run, + inputs=[input_image, max_size_input, resolution_level, apply_mask, remove_edges], + outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message] + ).then( + fn=reset_measure, + inputs=[results], + outputs=[measure_image, measure_points, measure_text] + ) + + measure_image.select( + fn=measure, + inputs=[results, measure_points], + outputs=[measure_image, measure_points, measure_text] + ) + + demo.launch(share=share) + + +if __name__ == '__main__': + main() diff --git a/assets/overview_simplified.png b/assets/overview_simplified.png new file mode 100644 index 0000000000000000000000000000000000000000..60a958eb46578b30a14fec1cfaea1289df88391e --- /dev/null +++ b/assets/overview_simplified.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7025a671e863bddbc22e79dc3e2eca8b7aeaf35fe93f6ef7f2b18f4fc9e093e6 +size 414314 diff --git a/assets/panorama_pipeline.png b/assets/panorama_pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..334354c8a68ed7a9865c424f9890a72468b0a198 --- /dev/null +++ b/assets/panorama_pipeline.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed28c5309162bddda016ca600307ecc73f7e6415f9eaaefb9f6fffadf6951aaa +size 738233 diff --git a/baselines/da_v2.py b/baselines/da_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..bca560a75514bdfa38c9a28c8d36ea0e006dab1e --- /dev/null +++ b/baselines/da_v2.py @@ -0,0 +1,88 @@ +# Reference: https://github.com/DepthAnything/Depth-Anything-V2 +import os +import sys +from typing import * +from pathlib import Path + +import click +import torch +import torch.nn.functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +from moge.test.baseline import MGEBaselineInterface + + +class Baseline(MGEBaselineInterface): + def __init__(self, repo_path: str, backbone: str, num_tokens: int, device: Union[torch.device, str]): + # Create from repo + repo_path = os.path.abspath(repo_path) + if repo_path not in sys.path: + sys.path.append(repo_path) + if not Path(repo_path).exists(): + raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.') + from depth_anything_v2.dpt import DepthAnythingV2 + + device = torch.device(device) + + # Instantiate model + model = DepthAnythingV2(encoder=backbone, features=256, out_channels=[256, 512, 1024, 1024]) + + # Load checkpoint + checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_{backbone}.pth') + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + model.load_state_dict(checkpoint) + + model.to(device).eval() + self.model = model + self.num_tokens = num_tokens + self.device = device + + @click.command() + @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.') + @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Encoder architecture.') + @click.option('--num_tokens', type=int, default=None, help='Number of tokens to use for the input image.') + @click.option('--device', type=str, default='cuda', help='Device to use for inference.') + @staticmethod + def load(repo_path: str, backbone, num_tokens: int, device: torch.device = 'cuda'): + return Baseline(repo_path, backbone, num_tokens, device) + + @torch.inference_mode() + def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + original_height, original_width = image.shape[-2:] + + assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input" + + if image.ndim == 3: + image = image.unsqueeze(0) + omit_batch_dim = True + else: + omit_batch_dim = False + + if self.num_tokens is None: + resize_factor = 518 / min(original_height, original_width) + expected_width = round(original_width * resize_factor / 14) * 14 + expected_height = round(original_height * resize_factor / 14) * 14 + else: + aspect_ratio = original_width / original_height + tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5) + tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5) + expected_width = tokens_cols * 14 + expected_height = tokens_rows * 14 + image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True) + + image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + disparity = self.model(image) + + disparity = F.interpolate(disparity[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0] + + if omit_batch_dim: + disparity = disparity.squeeze(0) + + return { + 'disparity_affine_invariant': disparity + } + diff --git a/baselines/da_v2_metric.py b/baselines/da_v2_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4c70d8c6634babf165d2982a692230f5adeac6 --- /dev/null +++ b/baselines/da_v2_metric.py @@ -0,0 +1,99 @@ +# Reference https://github.com/DepthAnything/Depth-Anything-V2/metric_depth +import os +import sys +from typing import * +from pathlib import Path + +import click +import torch +import torch.nn.functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as TF +import cv2 + +from moge.test.baseline import MGEBaselineInterface + + +class Baseline(MGEBaselineInterface): + + def __init__(self, repo_path: str, backbone: str, domain: str, num_tokens: int, device: str): + device = torch.device(device) + repo_path = os.path.abspath(repo_path) + if not Path(repo_path).exists(): + raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.') + sys.path.append(os.path.join(repo_path, 'metric_depth')) + from depth_anything_v2.dpt import DepthAnythingV2 + + model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]} + } + + if domain == 'indoor': + dataset = 'hypersim' + max_depth = 20 + elif domain == 'outdoor': + dataset = 'vkitti' + max_depth = 80 + else: + raise ValueError(f"Invalid domain: {domain}") + + model = DepthAnythingV2(**model_configs[backbone], max_depth=max_depth) + checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_metric_{dataset}_{backbone}.pth') + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.') + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True)) + model.eval().to(device) + + self.model = model + self.num_tokens = num_tokens + self.device = device + + @click.command() + @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.') + @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Backbone architecture.') + @click.option('--domain', type=click.Choice(['indoor', 'outdoor']), help='Domain of the dataset.') + @click.option('--num_tokens', type=int, default=None, help='Number of tokens for the ViT model') + @click.option('--device', type=str, default='cuda', help='Device to use for inference.') + @staticmethod + def load(repo_path: str, backbone: str, domain: str, num_tokens: int, device: str): + return Baseline(repo_path, backbone, domain, num_tokens, device) + + @torch.inference_mode() + def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + original_height, original_width = image.shape[-2:] + + assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input" + + if image.ndim == 3: + image = image.unsqueeze(0) + omit_batch_dim = True + else: + omit_batch_dim = False + + if self.num_tokens is None: + resize_factor = 518 / min(original_height, original_width) + expected_width = round(original_width * resize_factor / 14) * 14 + expected_height = round(original_height * resize_factor / 14) * 14 + else: + aspect_ratio = original_width / original_height + tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5) + tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5) + expected_width = tokens_cols * 14 + expected_height = tokens_rows * 14 + image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True) + + image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + depth = self.model(image) + + depth = F.interpolate(depth[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0] + + if omit_batch_dim: + depth = depth.squeeze(0) + + return { + 'depth_metric': depth + } + diff --git a/baselines/metric3d_v2.py b/baselines/metric3d_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..661ed5ddfd3bf6f34f53dcb19bcc5a9889a8e377 --- /dev/null +++ b/baselines/metric3d_v2.py @@ -0,0 +1,117 @@ +# Reference: https://github.com/YvanYin/Metric3D +import os +import sys +from typing import * + +import click +import torch +import torch.nn.functional as F +import cv2 + +from moge.test.baseline import MGEBaselineInterface + + +class Baseline(MGEBaselineInterface): + def __init__(self, backbone: Literal['vits', 'vitl', 'vitg'], device): + backbone_map = { + 'vits': 'metric3d_vit_small', + 'vitl': 'metric3d_vit_large', + 'vitg': 'metric3d_vit_giant2' + } + + device = torch.device(device) + model = torch.hub.load('yvanyin/metric3d', backbone_map[backbone], pretrain=True) + model.to(device).eval() + + self.model = model + self.device = device + + @click.command() + @click.option('--backbone', type=click.Choice(['vits', 'vitl', 'vitg']), default='vitl', help='Encoder architecture.') + @click.option('--device', type=str, default='cuda', help='Device to use.') + @staticmethod + def load(backbone: str = 'vitl', device: torch.device = 'cuda'): + return Baseline(backbone, device) + + @torch.inference_mode() + def inference_one_image(self, image: torch.Tensor, intrinsics: torch.Tensor = None): + # Reference: https://github.com/YvanYin/Metric3D/blob/main/mono/utils/do_test.py + + # rgb_origin: RGB, 0-255, uint8 + rgb_origin = image.cpu().numpy().transpose((1, 2, 0)) * 255 + + # keep ratio resize + input_size = (616, 1064) # for vit model + h, w = rgb_origin.shape[:2] + scale = min(input_size[0] / h, input_size[1] / w) + rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + if intrinsics is not None: + focal = intrinsics[0, 0] * int(w * scale) + + # padding to input_size + padding = [123.675, 116.28, 103.53] + h, w = rgb.shape[:2] + pad_h = input_size[0] - h + pad_w = input_size[1] - w + pad_h_half = pad_h // 2 + pad_w_half = pad_w // 2 + rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding) + pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + + # normalize rgb + mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] + std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None] + rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() + rgb = torch.div((rgb - mean), std) + rgb = rgb[None, :, :, :].cuda() + + # inference + pred_depth, confidence, output_dict = self.model.inference({'input': rgb}) + + # un pad + pred_depth = pred_depth.squeeze() + pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]] + pred_depth = pred_depth.clamp_min(0.5) # clamp to 0.5m, since metric3d could yield very small depth values, resulting in crashed the scale shift alignment. + + # upsample to original size + pred_depth = F.interpolate(pred_depth[None, None, :, :], image.shape[-2:], mode='bilinear').squeeze() + + if intrinsics is not None: + # de-canonical transform + canonical_to_real_scale = focal / 1000.0 # 1000.0 is the focal length of canonical camera + pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric + pred_depth = torch.clamp(pred_depth, 0, 300) + + pred_normal, normal_confidence = output_dict['prediction_normal'].split([3, 1], dim=1) # see https://arxiv.org/abs/2109.09881 for details + + # un pad and resize to some size if needed + pred_normal = pred_normal.squeeze(0) + pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]] + + # you can now do anything with the normal + pred_normal = F.interpolate(pred_normal[None, :, :, :], image.shape[-2:], mode='bilinear').squeeze(0) + pred_normal = F.normalize(pred_normal, p=2, dim=0) + + return pred_depth, pred_normal.permute(1, 2, 0) + + @torch.inference_mode() + def infer(self, image: torch.Tensor, intrinsics: torch.Tensor = None): + # image: (B, H, W, 3) or (H, W, 3) + if image.ndim == 3: + pred_depth, pred_normal = self.inference_one_image(image, intrinsics) + else: + for i in range(image.shape[0]): + pred_depth_i, pred_normal_i = self.inference_one_image(image[i], intrinsics[i] if intrinsics is not None else None) + pred_depth.append(pred_depth_i) + pred_normal.append(pred_normal_i) + pred_depth = torch.stack(pred_depth, dim=0) + pred_normal = torch.stack(pred_normal, dim=0) + + if intrinsics is not None: + return { + "depth_metric": pred_depth, + } + else: + return { + "depth_scale_invariant": pred_depth, + } diff --git a/baselines/moge.py b/baselines/moge.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdfaae24d01649230b107503275318971888409 --- /dev/null +++ b/baselines/moge.py @@ -0,0 +1,83 @@ +import os +import sys +from typing import * +import importlib + +import click +import torch +import utils3d + +from moge.test.baseline import MGEBaselineInterface + + +class Baseline(MGEBaselineInterface): + + def __init__(self, num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'): + super().__init__() + from moge.model import import_model_class_by_version + MoGeModel = import_model_class_by_version(version) + self.version = version + + self.model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() + + self.device = torch.device(device) + self.num_tokens = num_tokens + self.resolution_level = resolution_level + self.use_fp16 = use_fp16 + + @click.command() + @click.option('--num_tokens', type=int, default=None) + @click.option('--resolution_level', type=int, default=9) + @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl') + @click.option('--fp16', 'use_fp16', is_flag=True) + @click.option('--device', type=str, default='cuda:0') + @click.option('--version', type=str, default='v1') + @staticmethod + def load(num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'): + return Baseline(num_tokens, resolution_level, pretrained_model_name_or_path, use_fp16, device, version) + + # Implementation for inference + @torch.inference_mode() + def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.FloatTensor] = None): + if intrinsics is not None: + fov_x, _ = utils3d.torch.intrinsics_to_fov(intrinsics) + fov_x = torch.rad2deg(fov_x) + else: + fov_x = None + output = self.model.infer(image, fov_x=fov_x, apply_mask=True, num_tokens=self.num_tokens) + + if self.version == 'v1': + return { + 'points_scale_invariant': output['points'], + 'depth_scale_invariant': output['depth'], + 'intrinsics': output['intrinsics'], + } + else: + return { + 'points_metric': output['points'], + 'depth_metric': output['depth'], + 'intrinsics': output['intrinsics'], + } + + @torch.inference_mode() + def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: torch.FloatTensor = None): + if intrinsics is not None: + fov_x, _ = utils3d.torch.intrinsics_to_fov(intrinsics) + fov_x = torch.rad2deg(fov_x) + else: + fov_x = None + output = self.model.infer(image, fov_x=fov_x, apply_mask=False, num_tokens=self.num_tokens, use_fp16=self.use_fp16) + + if self.version == 'v1': + return { + 'points_scale_invariant': output['points'], + 'depth_scale_invariant': output['depth'], + 'intrinsics': output['intrinsics'], + } + else: + return { + 'points_metric': output['points'], + 'depth_metric': output['depth'], + 'intrinsics': output['intrinsics'], + } + diff --git a/configs/eval/all_benchmarks.json b/configs/eval/all_benchmarks.json new file mode 100644 index 0000000000000000000000000000000000000000..94c0fc4605f3a3472d7d39d4d8e40eb9e3d784b7 --- /dev/null +++ b/configs/eval/all_benchmarks.json @@ -0,0 +1,78 @@ +{ + "NYUv2": { + "path": "data/eval/NYUv2", + "width": 640, + "height": 480, + "split": ".index.txt", + "depth_unit": 1.0 + }, + "KITTI": { + "path": "data/eval/KITTI", + "width": 750, + "height": 375, + "split": ".index.txt", + "depth_unit": 1 + }, + "ETH3D": { + "path": "data/eval/ETH3D", + "width": 2048, + "height": 1365, + "split": ".index.txt", + "include_segmentation": true, + "depth_unit": 1 + }, + "iBims-1": { + "path": "data/eval/iBims-1", + "width": 640, + "height": 480, + "split": ".index.txt", + "has_sharp_boundary": true, + "include_segmentation": true, + "depth_unit": 1.0 + }, + "GSO": { + "path": "data/eval/GSO", + "width": 512, + "height": 512, + "split": ".index.txt" + }, + "Sintel": { + "path": "data/eval/Sintel", + "width": 872, + "height": 436, + "split": ".index.txt", + "has_sharp_boundary": true, + "include_segmentation": true + }, + "DDAD": { + "path": "data/eval/DDAD", + "width": 1400, + "height": 700, + "include_segmentation": true, + "split": ".index.txt", + "depth_unit": 1.0 + }, + "DIODE": { + "path": "data/eval/DIODE", + "width": 1024, + "height": 768, + "split": ".index.txt", + "include_segmentation": true, + "depth_unit": 1.0 + }, + "Spring": { + "path": "data/eval/Spring", + "width": 1920, + "height": 1080, + "split": ".index.txt", + "has_sharp_boundary": true + }, + "HAMMER": { + "path": "data/eval/HAMMER", + "width": 1664, + "height": 832, + "split": ".index.txt", + "depth_unit": 1, + "has_sharp_boundary": true + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/ddad.json b/configs/eval/benchmarks/ddad.json new file mode 100644 index 0000000000000000000000000000000000000000..09dd4d74bbccbb46a4013afd9fee1e717d606a53 --- /dev/null +++ b/configs/eval/benchmarks/ddad.json @@ -0,0 +1,9 @@ +{ + "DDAD": { + "path": "data/eval/DDAD", + "width": 1400, + "height": 700, + "include_segmentation": true, + "split": ".index.txt" + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/diode.json b/configs/eval/benchmarks/diode.json new file mode 100644 index 0000000000000000000000000000000000000000..679ca6ee13ddf5e5bcab93f453b2f11279781a2f --- /dev/null +++ b/configs/eval/benchmarks/diode.json @@ -0,0 +1,9 @@ +{ + "DIODE": { + "path": "data/eval/DIODE", + "width": 1024, + "height": 768, + "split": ".index.txt", + "include_segmentation": true + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/eth3d.json b/configs/eval/benchmarks/eth3d.json new file mode 100644 index 0000000000000000000000000000000000000000..88a3a1b291dcde3f2959c0d36d7ebbc33213fc84 --- /dev/null +++ b/configs/eval/benchmarks/eth3d.json @@ -0,0 +1,10 @@ +{ + "ETH3D": { + "path": "data/eval/ETH3D", + "width": 2048, + "height": 1365, + "split": ".index.txt", + "include_segmentation": true, + "depth_unit": 1 + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/gso.json b/configs/eval/benchmarks/gso.json new file mode 100644 index 0000000000000000000000000000000000000000..ee1aefff7ae3453b0cdddf7ab3369301d2e8d924 --- /dev/null +++ b/configs/eval/benchmarks/gso.json @@ -0,0 +1,8 @@ +{ + "GSO": { + "path": "data/eval/GSO", + "width": 512, + "height": 512, + "split": ".index.txt" + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/hammer.json b/configs/eval/benchmarks/hammer.json new file mode 100644 index 0000000000000000000000000000000000000000..41838db6bfcf2ea6f3ed230b6c7ee3315ec3fbfe --- /dev/null +++ b/configs/eval/benchmarks/hammer.json @@ -0,0 +1,10 @@ +{ + "HAMMER": { + "path": "data/eval/HAMMER", + "width": 1664, + "height": 832, + "split": ".index.txt", + "depth_unit": 1, + "has_sharp_boundary": true + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/ibims-1.json b/configs/eval/benchmarks/ibims-1.json new file mode 100644 index 0000000000000000000000000000000000000000..a6f0a0387891deb09bcae61bc4e4098e04db7307 --- /dev/null +++ b/configs/eval/benchmarks/ibims-1.json @@ -0,0 +1,10 @@ +{ + "iBims-1": { + "path": "data/eval/iBims-1", + "width": 640, + "height": 480, + "split": ".index.txt", + "include_segmentation": true, + "has_sharp_boundary": true + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/kitti.json b/configs/eval/benchmarks/kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..10ca7c3eb560649ce25edbf4ed5c835e90396cb8 --- /dev/null +++ b/configs/eval/benchmarks/kitti.json @@ -0,0 +1,9 @@ +{ + "KITTI": { + "path": "data/eval/KITTI", + "width": 750, + "height": 375, + "split": ".index.txt", + "depth_unit": 1 + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/nyu.json b/configs/eval/benchmarks/nyu.json new file mode 100644 index 0000000000000000000000000000000000000000..62841335b17f508ca903634b51b70f3e8a576186 --- /dev/null +++ b/configs/eval/benchmarks/nyu.json @@ -0,0 +1,8 @@ +{ + "NYUv2": { + "path": "data/eval/NYUv2", + "width": 640, + "height": 480, + "split": ".test.txt" + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/sintel.json b/configs/eval/benchmarks/sintel.json new file mode 100644 index 0000000000000000000000000000000000000000..fde872e282e260f987208168bdcd166a104732d3 --- /dev/null +++ b/configs/eval/benchmarks/sintel.json @@ -0,0 +1,10 @@ +{ + "Sintel": { + "path": "data/eval/Sintel", + "width": 872, + "height": 436, + "split": ".index.txt", + "include_segmentation": true, + "has_sharp_boundary": true + } +} \ No newline at end of file diff --git a/configs/eval/benchmarks/spring.json b/configs/eval/benchmarks/spring.json new file mode 100644 index 0000000000000000000000000000000000000000..a18e51a969fe5b605c03ed1f0a4714dec9379539 --- /dev/null +++ b/configs/eval/benchmarks/spring.json @@ -0,0 +1,9 @@ +{ + "Spring": { + "path": "data/eval/Spring", + "width": 1920, + "height": 1080, + "split": ".test.txt", + "has_sharp_boundary": true + } +} \ No newline at end of file diff --git a/configs/train/v1.json b/configs/train/v1.json new file mode 100644 index 0000000000000000000000000000000000000000..f87f38944129e60da58eed20395d36f4404d7164 --- /dev/null +++ b/configs/train/v1.json @@ -0,0 +1,77 @@ +{ + "data": { + "aspect_ratio_range": [0.5, 2.0], + "area_range": [250000, 1000000], + "clamp_max_depth": 1000.0, + "center_augmentation": 0.5, + "fov_range_absolute": [1, 179], + "fov_range_relative": [0.01, 1.0], + "image_augmentation": ["jittering", "jpeg_loss", "blurring"], + "datasets": [ + { + "name": "TartanAir", + "path": "blobmnt/data_v3/TartanAir", + "label_type": "synthetic", + "index": ".index.txt", + "depth": "depth.png", + "weight": 4.8, + "center_augmentation": 0.25, + "fov_range_absolute": [30, 150], + "fov_range_relative": [0.5, 1.0], + "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"] + } + ] + }, + "model_version": "v1", + "model": { + "encoder": "dinov2_vitl14", + "remap_output": "exp", + "intermediate_layers": 4, + "dim_upsample": [256, 128, 64], + "dim_times_res_block_hidden": 2, + "num_res_blocks": 2, + "num_tokens_range": [1200, 2500], + "last_conv_channels": 32, + "last_conv_size": 1 + }, + "optimizer": { + "type": "AdamW", + "params": [ + {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4}, + {"params": {"include": ["*backbone.*"]}, "lr": 1e-5} + ] + }, + "lr_scheduler": { + "type": "SequentialLR", + "params": { + "schedulers": [ + {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}}, + {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}} + ], + "milestones": [2000] + } + }, + "low_resolution_training_steps": 50000, + "loss": { + "invalid": {}, + "synthetic": { + "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, + "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, + "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, + "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}}, + "normal": {"function": "normal_loss", "weight": 1.0}, + "mask": {"function": "mask_l2_loss", "weight": 1.0} + }, + "sfm": { + "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, + "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, + "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, + "mask": {"function": "mask_l2_loss", "weight": 1.0} + }, + "lidar": { + "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, + "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, + "mask": {"function": "mask_l2_loss", "weight": 1.0} + } + } +} \ No newline at end of file diff --git a/docs/eval.md b/docs/eval.md new file mode 100644 index 0000000000000000000000000000000000000000..a9d93e4a540c6df1c06aaa5694c8377e67ba468f --- /dev/null +++ b/docs/eval.md @@ -0,0 +1,77 @@ +# Evaluation + +We provide a unified evaluation script that runs baselines on multiple benchmarks. It takes a baseline model and evaluation configurations, evaluates on-the-fly, and reports results instantly in a JSON file. + +## Benchmarks + +Donwload the processed datasets from [Huggingface Datasets](https://huggingface.co/datasets/Ruicheng/monocular-geometry-evaluation) and put them in the `data/eval` directory, using `huggingface-cli`: + +```bash +mkdir -p data/eval +huggingface-cli download Ruicheng/monocular-geometry-evaluation --repo-type dataset --local-dir data/eval --local-dir-use-symlinks False +``` + +Then unzip the downloaded files: + +```bash +cd data/eval +unzip '*.zip' +# rm *.zip # if you don't keep the zip files +``` + +## Configuration + +See [`configs/eval/all_benchmarks.json`](../configs/eval/all_benchmarks.json) for an example of evaluation configurations on all benchmarks. You can modify this file to evaluate on different benchmarks or different baselines. + +## Baseline + +Some examples of baselines are provided in [`baselines/`](../baselines/). Pass the path to the baseline model python code to the `--baseline` argument of the evaluation script. + +## Run Evaluation + +Run the script [`moge/scripts/eval_baseline.py`](../moge/scripts/eval_baseline.py). +For example, + +```bash +# Evaluate MoGe on the 10 benchmarks +python moge/scripts/eval_baseline.py --baseline baselines/moge.py --config configs/eval/all_benchmarks.json --output eval_output/moge.json --pretrained Ruicheng/moge-vitl --resolution_level 9 + +# Evaluate Depth Anything V2 on the 10 benchmarks. (NOTE: affine disparity) +python moge/scripts/eval_baseline.py --baseline baselines/da_v2.py --config configs/eval/all_benchmarks.json --output eval_output/da_v2.json +``` + +The `--baselies` `--input` `--output` arguments are for the inference script. The rest arguments, e.g. `--pretrained` `--resolution_level`, are custormized for loading the baseline model. + +Details of the arguments: + +``` +Usage: eval_baseline.py [OPTIONS] + + Evaluation script. + +Options: + --baseline PATH Path to the baseline model python code. + --config PATH Path to the evaluation configurations. Defaults to + "configs/eval/all_benchmarks.json". + --output PATH Path to the output json file. + --oracle Use oracle mode for evaluation, i.e., use the GT intrinsics + input. + --dump_pred Dump predition results. + --dump_gt Dump ground truth. + --help Show this message and exit. +``` + + + +## Wrap a Customized Baseline + +Wrap any baseline method with [`moge.test.baseline.MGEBaselineInterface`](../moge/test/baseline.py). +See [`baselines/`](../baselines/) for more examples. + +It is a good idea to check the correctness of the baseline implementation by running inference on a small set of images via [`moge/scripts/infer_baselines.py`](../moge/scripts/infer_baselines.py): + +```base +python moge/scripts/infer_baselines.py --baseline baselines/moge.py --input example_images/ --output infer_outupt/moge --pretrained Ruicheng/moge-vitl --maps --ply +``` + + diff --git a/docs/train.md b/docs/train.md new file mode 100644 index 0000000000000000000000000000000000000000..170abb80e08ac5eb25badedc2b05138c21bb33f2 --- /dev/null +++ b/docs/train.md @@ -0,0 +1,181 @@ + +# Training + +This document provides instructions for training and finetuning the MoGe model. + +## Additional Requirements + +The following packages other than those listed in [`pyproject.toml`](../pyproject.toml) are required for training and finetuning the MoGe model: + +``` +accelerate +sympy +mlflow +``` + +## Data preparation + +### Dataset format + +Each dataset should be organized as follows: + +``` +somedataset +├── .index.txt # A list of instance paths +├── folder1 +│ ├── instance1 # Each instance is in a folder +│ │ ├── image.jpg # RGB image. +│ │ ├── depth.png # 16-bit depth. See moge/utils/io.py for details +│ │ ├── meta.json # Stores "intrinsics" as a 3x3 matrix +│ │ └── ... # Other componests such as segmentation mask, normal map etc. +... +``` + +* `.index.txt` is placed at top directory to store a list of instance paths in this dataset. The dataloader will look for instances in this list. You may also use a custom split, e.g. `.train.txt`, `.val.txt` and specify it in the configuration file. + +* For depth images, it is recommended to use `read_depth()` and `write_depth()` in [`moge/utils/io.py`](../moge/utils/io.py) to read and write depth images. The depth is stored in logarithmic scale in 16-bit PNG format, offering a balanced precision, dynamic range and compression ratio compared to 16-bit and 32-bit EXR and linear depth formats. It also encodes `NaN` and `Inf` values for invalid depth values. + +* The `meta.json` should be a dictionary containing the key `intrinsics`, which are **normalized** camera parameters. You may put more metadata. + +* We also support reading and storing segementation masks for evaluation data (see paper evaluation of local points), which are saved in PNG format with semantic labels stored in png metadata as JSON strings. See `read_segmentation()` and `write_segmentation()` in [`moge/utils/io.py`](../moge/utils/io.py) for details. + + +### Visual inspection + +We provide a script to visualize the data and check the data quality. It will export the instance as a PLY file for visualization of point cloud. + +```bash +python moge/scripts/vis_data.py PATH_TO_INSTANCE --ply [-o SOMEWHERE_ELSE_TO_SAVE_VIS] +``` + +### DataLoader + +Our training dataloaders is customized to handle loading data, performing perspective crop, and augmentation in a multithreading pipeline. Please refer to [`moge/train/dataloader.py`](../moge/train/dataloader.py) if you have any concern. + + +## Configuration + +See [`configs/train/v1.json`](../configs/train/v1.json) for an example configuration file. The configuration file defines the hyperparameters for training the MoGe model. +Here is a commented configuration for reference: + +```json +{ + "data": { + "aspect_ratio_range": [0.5, 2.0], # Range of aspect ratio of sampled images + "area_range": [250000, 1000000], # Range of sampled image area in pixels + "clamp_max_depth": 1000.0, # Maximum far/near + "center_augmentation": 0.5, # Ratio of center crop augmentation + "fov_range_absolute": [1, 179], # Absolute range of FOV in degrees + "fov_range_relative": [0.01, 1.0], # Relative range of FOV to the original FOV + "image_augmentation": ["jittering", "jpeg_loss", "blurring"], # List of image augmentation techniques + "datasets": [ + { + "name": "TartanAir", # Name of the dataset. Name it as you like. + "path": "data/TartanAir", # Path to the dataset + "label_type": "synthetic", # Label type for this dataset. Losses will be applied accordingly. see "loss" config + "weight": 4.8, # Probability of sampling this dataset + "index": ".index.txt", # File name of the index file. Defaults to .index.txt + "depth": "depth.png", # File name of depth images. Defaults to depth.png + "center_augmentation": 0.25, # Below are dataset-specific hyperparameters. Overriding the global ones above. + "fov_range_absolute": [30, 150], + "fov_range_relative": [0.5, 1.0], + "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"] + } + ] + }, + "model_version": "v1", # Model version. If you have multiple model variants, you can use this to switch between them. + "model": { # Model hyperparameters. Will be passed to Model __init__() as kwargs. + "encoder": "dinov2_vitl14", + "remap_output": "exp", + "intermediate_layers": 4, + "dim_upsample": [256, 128, 64], + "dim_times_res_block_hidden": 2, + "num_res_blocks": 2, + "num_tokens_range": [1200, 2500], + "last_conv_channels": 32, + "last_conv_size": 1 + }, + "optimizer": { # Reflection-like optimizer configurations. See moge.train.utils.py build_optimizer() for details. + "type": "AdamW", + "params": [ + {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4}, + {"params": {"include": ["*backbone.*"]}, "lr": 1e-5} + ] + }, + "lr_scheduler": { # Reflection-like lr_scheduler configurations. See moge.train.utils.py build_lr_scheduler() for details. + "type": "SequentialLR", + "params": { + "schedulers": [ + {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}}, + {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}} + ], + "milestones": [2000] + } + }, + "low_resolution_training_steps": 50000, # Total number of low-resolution training steps. It makes the early stage training faster. Later stage training on varying size images will be slower. + "loss": { + "invalid": {}, # invalid instance due to runtime error when loading data + "synthetic": { # Below are loss hyperparameters + "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, + "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, + "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, + "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}}, + "normal": {"function": "normal_loss", "weight": 1.0}, + "mask": {"function": "mask_l2_loss", "weight": 1.0} + }, + "sfm": { + "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, + "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, + "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, + "mask": {"function": "mask_l2_loss", "weight": 1.0} + }, + "lidar": { + "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, + "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, + "mask": {"function": "mask_l2_loss", "weight": 1.0} + } + } +} +``` + +## Run Training + +Launch the training script [`moge/scripts/train.py`](../moge/scripts/train.py). Note that we use [`accelerate`](https://github.com/huggingface/accelerate) for distributed training. + +```bash +accelerate launch \ + --num_processes 8 \ + moge/scripts/train.py \ + --config configs/train/v1.json \ + --workspace workspace/debug \ + --gradient_accumulation_steps 2 \ + --batch_size_forward 2 \ + --checkpoint latest \ + --enable_gradient_checkpointing True \ + --vis_every 1000 \ + --enable_mlflow True +``` + + +## Finetuning + +To finetune the pre-trained MoGe model, download the model checkpoint and put it in a local directory, e.g. `pretrained/moge-vitl.pt`. + +> NOTE: when finetuning pretrained MoGe model, a much lower learning rate is required. +The suggested learning rate for finetuning is not greater than 1e-5 for the head and 1e-6 for the backbone. +And the batch size is recommended to be 32 at least. +The settings in default configuration are not optimal for specific datasets and may require further tuning. + +```bash +accelerate launch \ + --num_processes 8 \ + moge/scripts/train.py \ + --config configs/train/v1.json \ + --workspace workspace/debug \ + --gradient_accumulation_steps 2 \ + --batch_size_forward 2 \ + --checkpoint pretrained/moge-vitl.pt \ + --enable_gradient_checkpointing True \ + --vis_every 1000 \ + --enable_mlflow True +``` diff --git a/example_images/01_HouseIndoor.jpg b/example_images/01_HouseIndoor.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eee8b1f17491b5d5602a54b257e55fe3d09a3d20 --- /dev/null +++ b/example_images/01_HouseIndoor.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3eb519bc68d4262af0c68166ca69e786cac5f6656a1083f4c585c4a94005c859 +size 322353 diff --git a/example_images/02_Office.jpg b/example_images/02_Office.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a21eec3de0c64ed4a8ce9cc612145673882d07d --- /dev/null +++ b/example_images/02_Office.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28767640002f93b703b24a34a6d75ca24b1ef093a19f52ef0f9d3b074ef68c61 +size 197508 diff --git a/example_images/03_Traffic.jpg b/example_images/03_Traffic.jpg new file mode 100644 index 0000000000000000000000000000000000000000..457784f7e0371cdf2aa5b2d37dd959dbb3bc4c36 --- /dev/null +++ b/example_images/03_Traffic.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fa8b46849dd3de5b3b0a141d6aafe98e190f578ccec0c9dacc440cd8434db11 +size 1125098 diff --git a/example_images/04_BunnyCake.jpg b/example_images/04_BunnyCake.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7e40c8ab15499909c74d39bba12c210aae1be694 --- /dev/null +++ b/example_images/04_BunnyCake.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ddd187d91ebc2cf626bc51a26e1fc71d478237ce348732ae547f83655f05260 +size 69126 diff --git a/example_images/05_Mountain.jpg b/example_images/05_Mountain.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df9c2c8686c175cfce2273d8c0254485528399de --- /dev/null +++ b/example_images/05_Mountain.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:670d322f6588713f7d9c7349091de0aacb2a5b0b37c7b7433995e110fb2bcfbc +size 665958 diff --git a/example_images/06_MaitreyaBuddha.png b/example_images/06_MaitreyaBuddha.png new file mode 100644 index 0000000000000000000000000000000000000000..72193f4b66cb3d2f5583a6128bdcb5f10037d486 --- /dev/null +++ b/example_images/06_MaitreyaBuddha.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:396c5fd722bf5a21b931cbb70b883d6b1d5f9bab439cc426ec2f606fc2b7872d +size 1224680 diff --git a/example_images/07_Breads.jpg b/example_images/07_Breads.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0029d0c04179f2863f79a5429460122d41943560 --- /dev/null +++ b/example_images/07_Breads.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a95c2cab81412e252ee5a56a6100df31bb83de0f117607ca8476478f7f152a7b +size 156435 diff --git a/example_images/08_CatGirl.png b/example_images/08_CatGirl.png new file mode 100644 index 0000000000000000000000000000000000000000..664ef2a6bf02e1c1720f1ed19e00e57d2a839927 --- /dev/null +++ b/example_images/08_CatGirl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57fa6d587d598e7a428e8997b86d5c3a06e0e18529bfad8bab78ae03a1f5820f +size 1689759 diff --git a/example_images/09_Restaurant.jpg b/example_images/09_Restaurant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87aa321a35339878b095c791e2f90aa49c0ba6be --- /dev/null +++ b/example_images/09_Restaurant.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2bb7b5a1e91a174101109b0976b8ae2a4d6bb7d6eadad6569106ed102d0d5a6 +size 794391 diff --git a/example_images/10_MedievalVillage.jpg b/example_images/10_MedievalVillage.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9eb958edb1b7a632bc91a4087acee52c8b557005 --- /dev/null +++ b/example_images/10_MedievalVillage.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:718ed1aeb1e0010194c5cf0e95371e6a29d45b84e93efbed63ff4cc60e74508b +size 465285 diff --git a/example_images/11_Room.jpg b/example_images/11_Room.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6093803db8e18d41e205d840c724e206e61d747 --- /dev/null +++ b/example_images/11_Room.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f34b99e89f3a57952bb88f11a6dc87e4a75423f55ad26748783c92854543cf5 +size 581651 diff --git a/example_images/12_StylizedHouses.jpg b/example_images/12_StylizedHouses.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16b7cca5ce92debff73bbd3bfb1d01e2da816856 --- /dev/null +++ b/example_images/12_StylizedHouses.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18120b27ea499ef9c921a5a02e987c687327896c7bb649a9703682737d25a6b8 +size 1243499 diff --git a/example_images/panorama/Braunschweig_Panoram.jpg b/example_images/panorama/Braunschweig_Panoram.jpg new file mode 100644 index 0000000000000000000000000000000000000000..847fe2715173a2569dda1203e3e68ec85150b607 --- /dev/null +++ b/example_images/panorama/Braunschweig_Panoram.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abc31b78f03a0b5254f3735bc3201c28d21b6855708f971ce4b6a740dfbddcba +size 562674 diff --git a/moge/__init__.py b/moge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/moge/model/__init__.py b/moge/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c919e3be42c0005752e8c800129bd5f724b47ff9 --- /dev/null +++ b/moge/model/__init__.py @@ -0,0 +1,18 @@ +import importlib +from typing import * + +if TYPE_CHECKING: + from .v1 import MoGeModel as MoGeModelV1 + from .v2 import MoGeModel as MoGeModelV2 + + +def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]: + assert version in ['v1', 'v2'], f'Unsupported model version: {version}' + + try: + module = importlib.import_module(f'.{version}', __package__) + except ModuleNotFoundError: + raise ValueError(f'Model version "{version}" not found.') + + cls = getattr(module, 'MoGeModel') + return cls diff --git a/moge/model/dinov2/__init__.py b/moge/model/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/moge/model/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/moge/model/dinov2/hub/__init__.py b/moge/model/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/moge/model/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/moge/model/dinov2/hub/backbones.py b/moge/model/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/moge/model/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/moge/model/dinov2/hub/utils.py b/moge/model/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/moge/model/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/moge/model/dinov2/layers/__init__.py b/moge/model/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/moge/model/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/moge/model/dinov2/layers/attention.py b/moge/model/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400 --- /dev/null +++ b/moge/model/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/moge/model/dinov2/layers/block.py b/moge/model/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d --- /dev/null +++ b/moge/model/dinov2/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/moge/model/dinov2/layers/dino_head.py b/moge/model/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/moge/model/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/moge/model/dinov2/layers/drop_path.py b/moge/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/moge/model/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/moge/model/dinov2/layers/layer_scale.py b/moge/model/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/moge/model/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/moge/model/dinov2/layers/mlp.py b/moge/model/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/moge/model/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/moge/model/dinov2/layers/patch_embed.py b/moge/model/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/moge/model/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/moge/model/dinov2/layers/swiglu_ffn.py b/moge/model/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/moge/model/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/moge/model/dinov2/models/__init__.py b/moge/model/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/moge/model/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/moge/model/dinov2/models/vision_transformer.py b/moge/model/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1007ba57ddb35109c91716f1f5bf203db346e7be --- /dev/null +++ b/moge/model/dinov2/models/vision_transformer.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/moge/model/dinov2/utils/__init__.py b/moge/model/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/moge/model/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/moge/model/dinov2/utils/cluster.py b/moge/model/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/moge/model/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/moge/model/dinov2/utils/config.py b/moge/model/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/moge/model/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/moge/model/dinov2/utils/dtype.py b/moge/model/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/moge/model/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/moge/model/dinov2/utils/param_groups.py b/moge/model/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/moge/model/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/moge/model/dinov2/utils/utils.py b/moge/model/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/moge/model/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/moge/model/modules.py b/moge/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2b6731993a10bae0348a928b6018533cabcc1551 --- /dev/null +++ b/moge/model/modules.py @@ -0,0 +1,250 @@ +from typing import * +from numbers import Number +import importlib +import itertools +import functools +import sys + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +from .dinov2.models.vision_transformer import DinoVisionTransformer +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from ..utils.geometry_torch import normalized_view_plane_uv + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + kernel_size: int = 3, + padding_mode: str = 'replicate', + activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', + in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm', + hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm', + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = nn.ReLU + elif activation == 'leaky_relu': + activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2) + elif activation =='silu': + activation_cls = nn.SiLU + elif activation == 'elu': + activation_cls = nn.ELU + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \ + nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \ + nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \ + nn.Identity(), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \ + nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \ + nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\ + nn.Identity(), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode) + ) + + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class DINOv2Encoder(nn.Module): + "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]." + backbone: DinoVisionTransformer + image_mean: torch.Tensor + image_std: torch.Tensor + dim_features: int + + def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs): + super(DINOv2Encoder, self).__init__() + + self.intermediate_layers = intermediate_layers + + # Load the backbone + self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone) + self.backbone_name = backbone + self.backbone = self.hub_loader(pretrained=False) + + self.dim_features = self.backbone.blocks[0].attn.qkv.in_features + self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers) + + self.output_projections = nn.ModuleList([ + nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,) + for _ in range(self.num_features) + ]) + + self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def init_weights(self): + pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict() + self.backbone.load_state_dict(pretrained_backbone_state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward(self, image: torch.Tensor, token_rows: int, token_cols: int, return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=True) + image_14 = (image_14 - self.image_mean) / self.image_std + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True) + + # Project features to the desired dimensionality + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous()) + for proj, (feat, clstoken) in zip(self.output_projections, features) + ], dim=1).sum(dim=1) + + if return_class_token: + return x, features[-1][1] + else: + return x + + +class Resampler(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], + scale_factor: int = 2, + ): + if type_ == 'pixel_shuffle': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.PixelShuffle(scale_factor), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + for i in range(1, scale_factor ** 2): + self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2] + self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2] + elif type_ in ['nearest', 'bilinear']: + nn.Sequential.__init__(self, + nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + elif type_ == 'conv_transpose': + nn.Sequential.__init__(self, + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1] + elif type_ == 'pixel_unshuffle': + nn.Sequential.__init__(self, + nn.PixelUnshuffle(scale_factor), + nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + elif type_ == 'avg_pool': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + elif type_ == 'max_pool': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + else: + raise ValueError(f'Unsupported resampler type: {type_}') + +class MLP(nn.Sequential): + def __init__(self, dims: Sequence[int]): + nn.Sequential.__init__(self, + *itertools.chain(*[ + (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True)) + for dim_in, dim_out in zip(dims[:-2], dims[1:-1]) + ]), + nn.Linear(dims[-2], dims[-1]), + ) + + +class ConvStack(nn.Module): + def __init__(self, + dim_in: List[Optional[int]], + dim_res_blocks: List[int], + dim_out: List[Optional[int]], + resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm', + res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm', + activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', + ): + super().__init__() + self.input_blocks = nn.ModuleList([ + nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity() + for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks) + ]) + self.resamplers = nn.ModuleList([ + Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler) + for i, (dim_prev, dim_succ, resampler) in enumerate(zip( + dim_res_blocks[:-1], + dim_res_blocks[1:], + resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers) + )) + ]) + self.res_blocks = nn.ModuleList([ + nn.Sequential( + *( + ResidualConvBlock( + dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_, + activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm + ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks) + ) + ) for i, dim_res_block_ in enumerate(dim_res_blocks) + ]) + self.output_blocks = nn.ModuleList([ + nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity() + for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks) + ]) + + def enable_gradient_checkpointing(self): + for i in range(len(self.resamplers)): + self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i]) + for i in range(len(self.res_blocks)): + for j in range(len(self.res_blocks[i])): + self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j]) + + def forward(self, in_features: List[torch.Tensor]): + batch_shape = in_features[0].shape[:-3] + in_features = [x.reshape(-1, *x.shape[-3:]) for x in in_features] + + out_features = [] + for i in range(len(self.res_blocks)): + feature = self.input_blocks[i](in_features[i]) + if i == 0: + x = feature + elif feature is not None: + x = x + feature + x = self.res_blocks[i](x) + out_features.append(self.output_blocks[i](x)) + if i < len(self.res_blocks) - 1: + x = self.resamplers[i](x) + + out_features = [x.unflatten(0, batch_shape) for x in out_features] + return out_features diff --git a/moge/model/utils.py b/moge/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c50761d8740d9d0a0284e129503b8931c6fe08c4 --- /dev/null +++ b/moge/model/utils.py @@ -0,0 +1,49 @@ +from typing import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + module.__class__ = _AttentionWrapper + return module + + +def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]: + group_to_use = torch.distributed.group.WORLD + world_size = group_to_use.size() + grad = bucket.buffer() + grad.div_(world_size) + torch.distributed.all_reduce(grad, group=group_to_use) + fut = torch.futures.Future() + fut.set_result(grad) + return fut diff --git a/moge/model/v1.py b/moge/model/v1.py new file mode 100644 index 0000000000000000000000000000000000000000..1c14cc7ab3e03e9eed310fd547fc85d9e2a6ad9e --- /dev/null +++ b/moge/model/v1.py @@ -0,0 +1,392 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import importlib +import warnings +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +import utils3d +from huggingface_hub import hf_hub_download + + +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from ..utils.tools import timeit + + +class ResidualConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = lambda: nn.ReLU(inplace=True) + elif activation == 'leaky_relu': + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) + elif activation =='silu': + activation_cls = lambda: nn.SiLU(inplace=True) + elif activation == 'elu': + activation_cls = lambda: nn.ELU(inplace=True) + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode) + ) + + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Head(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1 + ): + super().__init__() + + self.projects = nn.ModuleList([ + nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features) + ]) + + self.upsample_blocks = nn.ModuleList([ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks)) + ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ]) + + self.output_block = nn.ModuleList([ + self._make_output_block( + dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm, + ) for dim_out_ in dim_out + ]) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']): + return nn.Sequential( + nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)), + nn.ReLU(inplace=True), + nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'), + ) + + def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + + # Process the hidden states + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) + for proj, (feat, clstoken) in zip(self.projects, hidden_states) + ], dim=1).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + if isinstance(self.output_block, nn.ModuleList): + output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block] + else: + output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False) + + return output + + +class MoGeModel(nn.Module): + image_mean: torch.Tensor + image_std: torch.Tensor + + def __init__(self, + encoder: str = 'dinov2_vitb14', + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear', + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + num_tokens_range: Tuple[Number, Number] = [1200, 2500], + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + mask_threshold: float = 0.5, + **deprecated_kwargs + ): + super(MoGeModel, self).__init__() + + if deprecated_kwargs: + # Process legacy arguments + if 'trained_area_range' in deprecated_kwargs: + num_tokens_range = [deprecated_kwargs['trained_area_range'][0] // 14 ** 2, deprecated_kwargs['trained_area_range'][1] // 14 ** 2] + del deprecated_kwargs['trained_area_range'] + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.encoder = encoder + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.num_tokens_range = num_tokens_range + self.mask_threshold = mask_threshold + + # NOTE: We have copied the DINOv2 code in torchhub to this repository. + # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues. + hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) + self.backbone = hub_loader(pretrained=False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + self.head = Head( + num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers), + dim_in=dim_feature, + dim_out=[3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size + ) + + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True) + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model']) + return model + + def init_weights(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == 'linear': + pass + elif self.remap_output =='sinh': + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + original_height, original_width = image.shape[-2:] + + # Resize to expected resolution defined by num_tokens + resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5 + resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor) + image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True) + + # Apply image transformation for DINOv2 + image = (image - self.image_mean) / self.image_std + image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True) + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) + + # Predict points (and mask) + output = self.head(features, image) + points, mask = output + + # Make sure fp32 precision for output + with torch.autocast(device_type=image.device.type, dtype=torch.float32): + # Resize to original resolution + points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False) + mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False) + + # Post-process points and mask + points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1) + points = self._remap_points(points) # slightly improves the performance in case of very large output values + + return_dict = {'points': points, 'mask': mask} + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + fov_x: Union[Number, torch.Tensor] = None, + resolution_level: int = 9, + num_tokens: int = None, + apply_mask: bool = True, + force_projection: bool = True, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\ + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `resolution_level`: An integer [0-9] for the resolution level for inference. + The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. + `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details. + - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. + `resolution_level` will be ignored if `num_tokens` is provided. Default: None + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + aspect_ratio = original_width / original_height + + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16): + output = self.forward(image, num_tokens) + points, mask = output['points'], output['mask'] + + # Always process the output in fp32 precision + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x]) + + mask_binary = mask > self.mask_threshold + + # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal) + if fov_x is None: + focal, shift = recover_focal_shift(points, mask_binary) + else: + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + depth = points[..., 2] + shift[..., None, None] + + # If projection constraint is forced, recompute the point map using the actual depth map + if force_projection: + points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics) + else: + points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :] + + # Apply mask if needed + if apply_mask: + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + return_dict = { + 'points': points, + 'intrinsics': intrinsics, + 'depth': depth, + 'mask': mask_binary, + } + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict \ No newline at end of file diff --git a/moge/model/v2.py b/moge/model/v2.py new file mode 100644 index 0000000000000000000000000000000000000000..eee351c6d50fc2f4bc7f0169bb3d44b9cca6a7a9 --- /dev/null +++ b/moge/model/v2.py @@ -0,0 +1,290 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.amp +import torch.version +import utils3d +from huggingface_hub import hf_hub_download + +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3 +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from .modules import DINOv2Encoder, MLP, ConvStack + + +class MoGeModel(nn.Module): + encoder: DINOv2Encoder + neck: ConvStack + points_head: ConvStack + mask_head: ConvStack + scale_head: MLP + + def __init__(self, + encoder: Dict[str, Any], + neck: Dict[str, Any], + points_head: Dict[str, Any] = None, + mask_head: Dict[str, Any] = None, + normal_head: Dict[str, Any] = None, + scale_head: Dict[str, Any] = None, + remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear', + num_tokens_range: List[int] = [1200, 3600], + **deprecated_kwargs + ): + super(MoGeModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.remap_output = remap_output + self.num_tokens_range = num_tokens_range + + self.encoder = DINOv2Encoder(**encoder) + self.neck = ConvStack(**neck) + if points_head is not None: + self.points_head = ConvStack(**points_head) + if mask_head is not None: + self.mask_head = ConvStack(**mask_head) + if normal_head is not None: + self.normal_head = ConvStack(**normal_head) + if scale_head is not None: + self.scale_head = MLP(**scale_head) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `compiled` + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint_path = pretrained_model_name_or_path + else: + checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model'], strict=False) + + return model + + def init_weights(self): + self.encoder.init_weights() + + def enable_gradient_checkpointing(self): + self.encoder.enable_gradient_checkpointing() + self.neck.enable_gradient_checkpointing() + for head in ['points_head', 'normal_head', 'mask_head']: + if hasattr(self, head): + getattr(self, head).enable_gradient_checkpointing() + + def enable_pytorch_native_sdpa(self): + self.encoder.enable_pytorch_native_sdpa() + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == 'linear': + pass + elif self.remap_output =='sinh': + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + batch_size, _, img_h, img_w = image.shape + device, dtype = image.device, image.dtype + + aspect_ratio = img_w / img_h + base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5) + num_tokens = base_h * base_w + + # Backbones encoding + features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True) + features = [features, None, None, None, None] + + # Concat UVs for aspect ratio input + for level in range(5): + uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1) + if features[level] is None: + features[level] = uv + else: + features[level] = torch.concat([features[level], uv], dim=1) + + # Shared neck + features = self.neck(features) + + # Heads decoding + points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head']) + metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None + + # Resize + points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask]) + + # Remap output + if points is not None: + points = points.permute(0, 2, 3, 1) + points = self._remap_points(points) # slightly improves the performance in case of very large output values + if normal is not None: + normal = normal.permute(0, 2, 3, 1) + normal = F.normalize(normal, dim=-1) + if mask is not None: + mask = mask.squeeze(1).sigmoid() + if metric_scale is not None: + metric_scale = metric_scale.squeeze(1).exp() + + return_dict = { + 'points': points, + 'normal': normal, + 'mask': mask, + 'metric_scale': metric_scale + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + num_tokens: int = None, + resolution_level: int = 9, + force_projection: bool = True, + apply_mask: Literal[False, True, 'blend'] = True, + fov_x: Optional[Union[Number, torch.Tensor]] = None, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W) + - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500. + More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`. + - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + area = original_height * original_width + aspect_ratio = original_width / original_height + + # Determine the number of base tokens to use + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)) + + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16): + output = self.forward(image, num_tokens=num_tokens) + points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale']) + + # Always process the output in fp32 precision + points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x]) + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + if mask is not None: + mask_binary = mask > 0.5 + else: + mask_binary = None + + if points is not None: + # Convert affine point map to camera-space. Recover depth and intrinsics from point map. + # NOTE: Focal here is the focal length relative to half the image diagonal + if fov_x is None: + # Recover focal and shift from predicted point map + focal, shift = recover_focal_shift(points, mask_binary) + else: + # Focal is known, recover shift only + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + points[..., 2] += shift[..., None, None] + if mask_binary is not None: + mask_binary &= points[..., 2] > 0 # in case depth is contains negative values (which should never happen in practice) + depth = points[..., 2].clone() + else: + depth, intrinsics = None, None + + # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics + if force_projection and depth is not None: + points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics) + + # Apply metric scale + if metric_scale is not None: + if points is not None: + points *= metric_scale[:, None, None, None] + if depth is not None: + depth *= metric_scale[:, None, None] + + # Apply mask + if apply_mask and mask_binary is not None: + points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None + depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None + normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None + + return_dict = { + 'points': points, + 'intrinsics': intrinsics, + 'depth': depth, + 'mask': mask_binary, + 'normal': normal + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict diff --git a/moge/scripts/__init__.py b/moge/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/moge/scripts/cli.py b/moge/scripts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..45c3b9006bf56306e403f8da5b6d5068215221ee --- /dev/null +++ b/moge/scripts/cli.py @@ -0,0 +1,27 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) + +import click + + +@click.group(help='MoGe command line interface.') +def cli(): + pass + +def main(): + from moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data + cli.add_command(app.main, name='app') + cli.add_command(infer.main, name='infer') + cli.add_command(infer_baseline.main, name='infer_baseline') + cli.add_command(infer_panorama.main, name='infer_panorama') + cli.add_command(eval_baseline.main, name='eval_baseline') + cli.add_command(vis_data.main, name='vis_data') + cli() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/moge/scripts/eval_baseline.py b/moge/scripts/eval_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..8217d9e6500b1d72a00e1a0a225ba4c2134b892e --- /dev/null +++ b/moge/scripts/eval_baseline.py @@ -0,0 +1,165 @@ +import os +import sys +from pathlib import Path +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +import json +from typing import * +import importlib +import importlib.util + +import click + + +@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Evaluation script.') +@click.option('--baseline', 'baseline_code_path', type=click.Path(), required=True, help='Path to the baseline model python code.') +@click.option('--config', 'config_path', type=click.Path(), default='configs/eval/all_benchmarks.json', help='Path to the evaluation configurations. ' + 'Defaults to "configs/eval/all_benchmarks.json".') +@click.option('--output', '-o', 'output_path', type=click.Path(), required=True, help='Path to the output json file.') +@click.option('--oracle', 'oracle_mode', is_flag=True, help='Use oracle mode for evaluation, i.e., use the GT intrinsics input.') +@click.option('--dump_pred', is_flag=True, help='Dump predition results.') +@click.option('--dump_gt', is_flag=True, help='Dump ground truth.') +@click.pass_context +def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool): + # Lazy import + import cv2 + import numpy as np + from tqdm import tqdm + import torch + import torch.nn.functional as F + import utils3d + + from moge.test.baseline import MGEBaselineInterface + from moge.test.dataloader import EvalDataLoaderPipeline + from moge.test.metrics import compute_metrics + from moge.utils.geometry_torch import intrinsics_to_fov + from moge.utils.vis import colorize_depth, colorize_normal + from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module + + # Load the baseline model + module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) + baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline') + baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False) + + # Load the evaluation configurations + with open(config_path, 'r') as f: + config = json.load(f) + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + all_metrics = {} + # Iterate over the dataset + for benchmark_name, benchmark_config in tqdm(list(config.items()), desc='Benchmarks'): + filenames, metrics_list = [], [] + with ( + EvalDataLoaderPipeline(**benchmark_config) as eval_data_pipe, + tqdm(total=len(eval_data_pipe), desc=benchmark_name, leave=False) as pbar + ): + # Iterate over the samples in the dataset + for i in range(len(eval_data_pipe)): + sample = eval_data_pipe.get() + sample = {k: v.to(baseline.device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()} + image = sample['image'] + gt_intrinsics = sample['intrinsics'] + + # Inference + torch.cuda.synchronize() + with torch.inference_mode(), timeit('_inference_timer', verbose=False) as timer: + if oracle_mode: + pred = baseline.infer_for_evaluation(image, gt_intrinsics) + else: + pred = baseline.infer_for_evaluation(image) + torch.cuda.synchronize() + + # Compute metrics + metrics, misc = compute_metrics(pred, sample, vis=dump_pred or dump_gt) + metrics['inference_time'] = timer.time + metrics_list.append(metrics) + + # Dump results + dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', '')) + if dump_pred: + dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + with Path(dump_path, 'pred', 'metrics.json').open('w') as f: + json.dump(metrics, f, indent=4) + + if 'pred_points' in misc: + points = misc['pred_points'].cpu().numpy() + cv2.imwrite(str(dump_path / 'pred' / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + if 'pred_depth' in misc: + depth = misc['pred_depth'].cpu().numpy() + if 'mask' in pred: + mask = pred['mask'].cpu().numpy() + depth = np.where(mask, depth, np.inf) + cv2.imwrite(str(dump_path / 'pred' / 'depth.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) + + if 'mask' in pred: + mask = pred['mask'].cpu().numpy() + cv2.imwrite(str(dump_path / 'pred' / 'mask.png'), (mask * 255).astype(np.uint8)) + + if 'normal' in pred: + normal = pred['normal'].cpu().numpy() + cv2.imwrite(str(dump_path / 'pred' / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR)) + + if 'intrinsics' in pred: + intrinsics = pred['intrinsics'] + fov_x, fov_y = intrinsics_to_fov(intrinsics) + with open(dump_path / 'pred' / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': np.rad2deg(fov_x.item()), + 'fov_y': np.rad2deg(fov_y.item()), + 'intrinsics': intrinsics.cpu().numpy().tolist(), + }, f) + + if dump_gt: + dump_path.joinpath('gt').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(dump_path / 'gt' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + if 'points' in sample: + points = sample['points'] + cv2.imwrite(str(dump_path / 'gt' / 'points.exr'), cv2.cvtColor(points.cpu().numpy().astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + if 'depth' in sample: + depth = sample['depth'] + mask = sample['depth_mask'] + cv2.imwrite(str(dump_path / 'gt' / 'depth.png'), cv2.cvtColor(colorize_depth(depth.cpu().numpy(), mask=mask.cpu().numpy()), cv2.COLOR_RGB2BGR)) + + if 'normal' in sample: + normal = sample['normal'] + cv2.imwrite(str(dump_path / 'gt' / 'normal.png'), cv2.cvtColor(colorize_normal(normal.cpu().numpy()), cv2.COLOR_RGB2BGR)) + + if 'depth_mask' in sample: + mask = sample['depth_mask'] + cv2.imwrite(str(dump_path / 'gt' /'mask.png'), (mask.cpu().numpy() * 255).astype(np.uint8)) + + if 'intrinsics' in sample: + intrinsics = sample['intrinsics'] + fov_x, fov_y = intrinsics_to_fov(intrinsics) + with open(dump_path / 'gt' / 'info.json', 'w') as f: + json.dump({ + 'fov_x': np.rad2deg(fov_x.item()), + 'fov_y': np.rad2deg(fov_y.item()), + 'intrinsics': intrinsics.cpu().numpy().tolist(), + }, f) + + # Save intermediate results + if i % 100 == 0 or i == len(eval_data_pipe) - 1: + Path(output_path).write_text( + json.dumps({ + **all_metrics, + benchmark_name: key_average(metrics_list) + }, indent=4) + ) + pbar.update(1) + + all_metrics[benchmark_name] = key_average(metrics_list) + + # Save final results + all_metrics['mean'] = key_average(list(all_metrics.values())) + Path(output_path).write_text(json.dumps(all_metrics, indent=4)) + + +if __name__ == '__main__': + main() diff --git a/moge/scripts/infer.py b/moge/scripts/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..a66dd553cd70d6f043a99425ee993bf26bedb0fd --- /dev/null +++ b/moge/scripts/infer.py @@ -0,0 +1,170 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +from typing import * +import itertools +import json +import warnings + + +import click + + +@click.command(help='Inference script') +@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.') +@click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.') +@click.option('--output', '-o', 'output_path', default='./output', type=click.Path(), help='Output folder path') +@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default=None, help='Pretrained model name or path. If not provided, the corresponding default model will be chosen.') +@click.option('--version', 'model_version', type=click.Choice(['v1', 'v2']), default='v2', help='Model version. Defaults to "v2"') +@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') +@click.option('--fp16', 'use_fp16', is_flag=True, help='Use fp16 precision for much faster inference.') +@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).') +@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level for inference. \ +Higher value means more tokens and the finer details will be captured, but inference can be slower. \ +Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. \ +`resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.') +@click.option('--num_tokens', type=int, default=None, help='number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. \ +`resolution_level` will be ignored if `num_tokens` is provided. Default: None') +@click.option('--threshold', type=float, default=0.01, help='Threshold for removing edges. Defaults to 0.01. Smaller value removes more edges. "inf" means no thresholding.') +@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps (image, point map, depth map, normal map, mask) and fov.') +@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') +@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') +@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.') +def main( + input_path: str, + fov_x_: float, + output_path: str, + pretrained_model_name_or_path: str, + model_version: str, + device_name: str, + use_fp16: bool, + resize_to: int, + resolution_level: int, + num_tokens: int, + threshold: float, + save_maps_: bool, + save_glb_: bool, + save_ply_: bool, + show: bool, +): + import cv2 + import numpy as np + import torch + from PIL import Image + from tqdm import tqdm + import trimesh + import trimesh.visual + import click + + from moge.model import import_model_class_by_version + from moge.utils.io import save_glb, save_ply + from moge.utils.vis import colorize_depth, colorize_normal + from moge.utils.geometry_numpy import depth_occlusion_edge_numpy + import utils3d + + device = torch.device(device_name) + + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if len(image_paths) == 0: + raise FileNotFoundError(f'No image files found in {input_path}') + + if pretrained_model_name_or_path is None: + DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = { + "v1": "Ruicheng/moge-vitl", + "v2": "Ruicheng/moge-2-vitl-normal", + } + pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version] + model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval() + if use_fp16: + model.half() + + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.') + save_maps_ = save_glb_ = save_ply_ = True + + for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image.shape[:2] + if resize_to is not None: + height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) + image = cv2.resize(image, (width, height), cv2.INTER_AREA) + image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1) + + # Inference + output = model.infer(image_tensor, fov_x=fov_x_, resolution_level=resolution_level, num_tokens=num_tokens, use_fp16=use_fp16) + points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy() + normal = output['normal'].cpu().numpy() if 'normal' in output else None + + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + save_path.mkdir(exist_ok=True, parents=True) + + # Save images / maps + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8)) + cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + if normal is not None: + cv2.imwrite(str(save_path / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR)) + fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + with open(save_path / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': round(float(np.rad2deg(fov_x)), 2), + 'fov_y': round(float(np.rad2deg(fov_y)), 2), + }, f) + + # Export mesh & visulization + if save_glb_ or save_ply_ or show: + mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=0.04) + if normal is None: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask_cleaned, + tri=True + ) + vertex_normals = None + else: + faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + normal, + mask=mask_cleaned, + tri=True + ) + # When exporting the model, follow the OpenGL coordinate conventions: + # - world coordinate system: x right, y up, z backward. + # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. + vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] + if normal is not None: + vertex_normals = vertex_normals * [1, -1, -1] + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image, vertex_normals) + + if save_ply_: + save_ply(save_path / 'pointcloud.ply', vertices, np.zeros((0, 3), dtype=np.int32), vertex_colors, vertex_normals) + + if show: + trimesh.Trimesh( + vertices=vertices, + vertex_colors=vertex_colors, + vertex_normals=vertex_normals, + faces=faces, + process=False + ).show() + + +if __name__ == '__main__': + main() diff --git a/moge/scripts/infer_baseline.py b/moge/scripts/infer_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..5409674f7cd5ce21de9200fd9038cb7d71c99e0f --- /dev/null +++ b/moge/scripts/infer_baseline.py @@ -0,0 +1,140 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +import json +from pathlib import Path +from typing import * +import itertools +import warnings + +import click + + +@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Inference script for wrapped baselines methods') +@click.option('--baseline', 'baseline_code_path', required=True, type=click.Path(), help='Path to the baseline model python code.') +@click.option('--input', '-i', 'input_path', type=str, required=True, help='Input image or folder') +@click.option('--output', '-o', 'output_path', type=str, default='./output', help='Output folder') +@click.option('--size', 'image_size', type=int, default=None, help='Resize input image') +@click.option('--skip', is_flag=True, help='Skip existing output') +@click.option('--maps', 'save_maps_', is_flag=True, help='Save output point / depth maps') +@click.option('--ply', 'save_ply_', is_flag=True, help='Save mesh in PLY format') +@click.option('--glb', 'save_glb_', is_flag=True, help='Save mesh in GLB format') +@click.option('--threshold', type=float, default=0.03, help='Depth edge detection threshold for saving mesh') +@click.pass_context +def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_, save_ply_: bool, save_glb_: bool, threshold: float): + # Lazy import + import cv2 + import numpy as np + from tqdm import tqdm + import torch + import utils3d + + from moge.utils.io import save_ply, save_glb + from moge.utils.geometry_numpy import intrinsics_to_fov_numpy + from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity + from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module + from moge.test.baseline import MGEBaselineInterface + + # Load the baseline model + module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) + baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline') + baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False) + + # Input images list + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Defaults to saving maps only. Please use "--maps", "--glb", or "--ply" to specify the output.') + save_maps_ = True + + for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): + # Load one image at a time + image_np = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image_np.shape[:2] + if image_size is not None and max(image_np.shape[:2]) > image_size: + height, width = min(image_size, int(image_size * height / width)), min(image_size, int(image_size * width / height)) + image_np = cv2.resize(image_np, (width, height), cv2.INTER_AREA) + image = torch.from_numpy(image_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(baseline.device) + + # Inference + torch.cuda.synchronize() + with torch.inference_mode(), (timer := timeit('Inference', verbose=False, average=True)): + output = baseline.infer(image) + torch.cuda.synchronize() + + inference_time = timer.average_time + pbar.set_postfix({'average inference time': f'{inference_time:.3f}s'}) + + # Save the output + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + if skip and save_path.exists(): + continue + save_path.mkdir(parents=True, exist_ok=True) + + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) + + if 'mask' in output: + mask = output['mask'].cpu().numpy() + cv2.imwrite(str(save_path /'mask.png'), (mask * 255).astype(np.uint8)) + + for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']: + if k in output: + points = output[k].cpu().numpy() + cv2.imwrite(str(save_path / f'{k}.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + for k in ['depth_metric', 'depth_scale_invariant', 'depth_affine_invariant', 'disparity_affine_invariant']: + if k in output: + depth = output[k].cpu().numpy() + cv2.imwrite(str(save_path / f'{k}.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + if k in ['depth_metric', 'depth_scale_invariant']: + depth_vis = colorize_depth(depth) + elif k == 'depth_affine_invariant': + depth_vis = colorize_depth_affine(depth) + elif k == 'disparity_affine_invariant': + depth_vis = colorize_disparity(depth) + cv2.imwrite(str(save_path / f'{k}_vis.png'), cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR)) + + if 'intrinsics' in output: + intrinsics = output['intrinsics'].cpu().numpy() + fov_x, fov_y = intrinsics_to_fov_numpy(intrinsics) + with open(save_path / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': float(np.rad2deg(fov_x)), + 'fov_y': float(np.rad2deg(fov_y)), + 'intrinsics': intrinsics.tolist() + }, f, indent=4) + + # Export mesh & visulization + if save_ply_ or save_glb_: + assert any(k in output for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']), 'No point map found in output' + points = next(output[k] for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant'] if k in output).cpu().numpy() + mask = output['mask'] if 'mask' in output else np.ones_like(points[..., 0], dtype=bool) + normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask) + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image_np.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + # When exporting the model, follow the OpenGL coordinate conventions: + # - world coordinate system: x right, y up, z backward. + # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. + vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image_np) + + if save_ply_: + save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) + +if __name__ == '__main__': + main() diff --git a/moge/scripts/infer_panorama.py b/moge/scripts/infer_panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..cce65cb90cd1c6750d42cdda4e72d4ce3a2c0549 --- /dev/null +++ b/moge/scripts/infer_panorama.py @@ -0,0 +1,162 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +from typing import * +import itertools +import json +import warnings + +import click + + +@click.command(help='Inference script for panorama images') +@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.') +@click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path') +@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"') +@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') +@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).') +@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.') +@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.') +@click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.') +@click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.') +@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') +@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') +@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') +@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.') +def main( + input_path: str, + output_path: str, + pretrained_model_name_or_path: str, + device_name: str, + resize_to: int, + resolution_level: int, + threshold: float, + batch_size: int, + save_splitted: bool, + save_maps_: bool, + save_glb_: bool, + save_ply_: bool, + show: bool, +): + # Lazy import + import cv2 + import numpy as np + from numpy import ndarray + import torch + from PIL import Image + from tqdm import tqdm, trange + import trimesh + import trimesh.visual + from scipy.sparse import csr_array, hstack, vstack + from scipy.ndimage import convolve + from scipy.sparse.linalg import lsmr + + import utils3d + from moge.model.v1 import MoGeModel + from moge.utils.io import save_glb, save_ply + from moge.utils.vis import colorize_depth + from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth + + + device = torch.device(device_name) + + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if len(image_paths) == 0: + raise FileNotFoundError(f'No image files found in {input_path}') + + # Write outputs + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.') + save_maps_ = save_glb_ = save_ply_ = True + + model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() + + for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)): + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image.shape[:2] + if resize_to is not None: + height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) + image = cv2.resize(image, (width, height), cv2.INTER_AREA) + + splitted_extrinsics, splitted_intriniscs = get_panorama_cameras() + splitted_resolution = 512 + splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution) + + # Infer each view + print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + + splitted_distance_maps, splitted_masks = [], [] + for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False): + image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2) + fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size]))) + fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device) + output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False) + distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy() + splitted_distance_maps.extend(list(distance_map)) + splitted_masks.extend(list(mask)) + + # Save splitted + if save_splitted: + splitted_save_path = Path(output_path, image_path.stem, 'splitted') + splitted_save_path.mkdir(exist_ok=True, parents=True) + for i in range(len(splitted_images)): + cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR)) + + # Merge + print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging') + + merging_width, merging_height = min(1920, width), min(960, height) + panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs) + panorama_depth = panorama_depth.astype(np.float32) + panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR) + panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0 + points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height)) + + # Write outputs + print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + save_path.mkdir(exist_ok=True, parents=True) + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8)) + + # Export mesh & visulization + if save_glb_ or save_ply_ or show: + normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask) + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) + + if save_ply_: + save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) + + if show: + trimesh.Trimesh( + vertices=vertices, + vertex_colors=vertex_colors, + faces=faces, + process=False + ).show() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/moge/scripts/train.py b/moge/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d96d3ad4d31ed0b6c30bbbbcd83033b227e90829 --- /dev/null +++ b/moge/scripts/train.py @@ -0,0 +1,452 @@ +import os +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +import json +import time +import random +from typing import * +import itertools +from contextlib import nullcontext +from concurrent.futures import ThreadPoolExecutor +import io + +import numpy as np +import cv2 +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.version +import accelerate +from accelerate import Accelerator, DistributedDataParallelKwargs +from accelerate.utils import set_seed +import utils3d +import click +from tqdm import tqdm, trange +import mlflow +torch.backends.cudnn.benchmark = False # Varying input size, make sure cudnn benchmark is disabled + +from moge.train.dataloader import TrainDataLoaderPipeline +from moge.train.losses import ( + affine_invariant_global_loss, + affine_invariant_local_loss, + edge_loss, + normal_loss, + mask_l2_loss, + mask_bce_loss, + monitoring, +) +from moge.train.utils import build_optimizer, build_lr_scheduler +from moge.utils.geometry_torch import intrinsics_to_fov +from moge.utils.vis import colorize_depth, colorize_normal +from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict +from moge.test.metrics import compute_metrics + + +@click.command() +@click.option('--config', 'config_path', type=str, default='configs/debug.json') +@click.option('--workspace', type=str, default='workspace/debug', help='Path to the workspace') +@click.option('--checkpoint', 'checkpoint_path', type=str, default=None, help='Path to the checkpoint to load') +@click.option('--batch_size_forward', type=int, default=8, help='Batch size for each forward pass on each device') +@click.option('--gradient_accumulation_steps', type=int, default=1, help='Number of steps to accumulate gradients') +@click.option('--enable_gradient_checkpointing', type=bool, default=True, help='Use gradient checkpointing in backbone') +@click.option('--enable_mixed_precision', type=bool, default=False, help='Use mixed precision training. Backbone is converted to FP16') +@click.option('--enable_ema', type=bool, default=True, help='Maintain an exponential moving average of the model weights') +@click.option('--num_iterations', type=int, default=1000000, help='Number of iterations to train the model') +@click.option('--save_every', type=int, default=10000, help='Save checkpoint every n iterations') +@click.option('--log_every', type=int, default=1000, help='Log metrics every n iterations') +@click.option('--vis_every', type=int, default=0, help='Visualize every n iterations') +@click.option('--num_vis_images', type=int, default=32, help='Number of images to visualize, must be a multiple of divided batch size') +@click.option('--enable_mlflow', type=bool, default=True, help='Log metrics to MLFlow') +@click.option('--seed', type=int, default=0, help='Random seed') +def main( + config_path: str, + workspace: str, + checkpoint_path: str, + batch_size_forward: int, + gradient_accumulation_steps: int, + enable_gradient_checkpointing: bool, + enable_mixed_precision: bool, + enable_ema: bool, + num_iterations: int, + save_every: int, + log_every: int, + vis_every: int, + num_vis_images: int, + enable_mlflow: bool, + seed: Optional[int], +): + # Load config + with open(config_path, 'r') as f: + config = json.load(f) + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision='fp16' if enable_mixed_precision else None, + kwargs_handlers=[ + DistributedDataParallelKwargs(find_unused_parameters=True) + ] + ) + device = accelerator.device + batch_size_total = batch_size_forward * gradient_accumulation_steps * accelerator.num_processes + + # Log config + if accelerator.is_main_process: + if enable_mlflow: + try: + mlflow.log_params({ + **click.get_current_context().params, + 'batch_size_total': batch_size_total, + }) + except: + print('Failed to log config to MLFlow') + Path(workspace).mkdir(parents=True, exist_ok=True) + with Path(workspace).joinpath('config.json').open('w') as f: + json.dump(config, f, indent=4) + + # Set seed + if seed is not None: + set_seed(seed, device_specific=True) + + # Initialize model + print('Initialize model') + with accelerator.local_main_process_first(): + from moge.model import import_model_class_by_version + MoGeModel = import_model_class_by_version(config['model_version']) + model = MoGeModel(**config['model']) + count_total_parameters = sum(p.numel() for p in model.parameters()) + print(f'Total parameters: {count_total_parameters}') + + # Set up EMA model + if enable_ema and accelerator.is_main_process: + ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: 0.999 * averaged_model_parameter + 0.001 * model_parameter + ema_model = torch.optim.swa_utils.AveragedModel(model, device=accelerator.device, avg_fn=ema_avg_fn) + + # Set gradient checkpointing + if enable_gradient_checkpointing: + model.enable_gradient_checkpointing() + import warnings + warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint") + + # Initalize optimizer & lr scheduler + optimizer = build_optimizer(model, config['optimizer']) + lr_scheduler = build_lr_scheduler(optimizer, config['lr_scheduler']) + + count_grouped_parameters = [sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups] + for i, count in enumerate(count_grouped_parameters): + print(f'- Group {i}: {count} parameters') + + # Attempt to load checkpoint + checkpoint: Dict[str, Any] + with accelerator.local_main_process_first(): + if checkpoint_path.endswith('.pt'): + # - Load specific checkpoint file + print(f'Load checkpoint: {checkpoint_path}') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + elif checkpoint_path == "latest": + # - Load latest + checkpoint_path = Path(workspace, 'checkpoint', 'latest.pt') + if checkpoint_path.exists(): + print(f'Load checkpoint: {checkpoint_path}') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + i_step = checkpoint['step'] + if 'model' not in checkpoint and (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists(): + print(f'Load model checkpoint: {checkpoint_model_path}') + checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model'] + if 'optimizer' not in checkpoint and (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists(): + print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}') + checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True)) + if enable_ema and accelerator.is_main_process: + if 'ema_model' not in checkpoint and (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists(): + print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}') + checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model'] + else: + checkpoint = None + elif checkpoint_path is not None: + # - Load by step number + i_step = int(checkpoint_path) + checkpoint = {'step': i_step} + if (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists(): + print(f'Load model checkpoint: {checkpoint_model_path}') + checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model'] + if (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists(): + print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}') + checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True)) + if enable_ema and accelerator.is_main_process: + if (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists(): + print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}') + checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model'] + else: + checkpoint = None + + if checkpoint is None: + # Initialize model weights + print('Initialize model weights') + with accelerator.local_main_process_first(): + model.init_weights() + initial_step = 0 + else: + model.load_state_dict(checkpoint['model'], strict=False) + if 'step' in checkpoint: + initial_step = checkpoint['step'] + 1 + else: + initial_step = 0 + if 'optimizer' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + if enable_ema and accelerator.is_main_process and 'ema_model' in checkpoint: + ema_model.module.load_state_dict(checkpoint['ema_model'], strict=False) + if 'lr_scheduler' in checkpoint: + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + del checkpoint + + model, optimizer = accelerator.prepare(model, optimizer) + if torch.version.hip and isinstance(model, torch.nn.parallel.DistributedDataParallel): + # Hacking potential gradient synchronization issue in ROCm backend + from moge.model.utils import sync_ddp_hook + model.register_comm_hook(None, sync_ddp_hook) + + # Initialize training data pipeline + with accelerator.local_main_process_first(): + train_data_pipe = TrainDataLoaderPipeline(config['data'], batch_size_forward) + + def _write_bytes_retry_loop(save_path: Path, data: bytes): + while True: + try: + save_path.write_bytes(data) + break + except Exception as e: + print('Error while saving checkpoint, retrying in 1 minute: ', e) + time.sleep(60) + + # Ready to train + records = [] + model.train() + with ( + train_data_pipe, + tqdm(initial=initial_step, total=num_iterations, desc='Training', disable=not accelerator.is_main_process) as pbar, + ThreadPoolExecutor(max_workers=1) as save_checkpoint_executor, + ): + # Get some batches for visualization + if accelerator.is_main_process: + batches_for_vis: List[Dict[str, torch.Tensor]] = [] + num_vis_images = num_vis_images // batch_size_forward * batch_size_forward + for _ in range(num_vis_images // batch_size_forward): + batch = train_data_pipe.get() + batches_for_vis.append(batch) + + # Visualize GT + if vis_every > 0 and accelerator.is_main_process and initial_step == 0: + save_dir = Path(workspace).joinpath('vis/gt') + for i_batch, batch in enumerate(tqdm(batches_for_vis, desc='Visualize GT', leave=False)): + image, gt_depth, gt_mask, gt_mask_inf, gt_intrinsics, info = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_inf'], batch['intrinsics'], batch['info'] + gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics) + gt_normal, gt_normal_mask = utils3d.torch.points_to_normals(gt_points, gt_mask) + for i_instance in range(batch['image'].shape[0]): + idx = i_batch * batch_size_forward + i_instance + image_i = (image[i_instance].numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + gt_depth_i = gt_depth[i_instance].numpy() + gt_mask_i = gt_mask[i_instance].numpy() + gt_mask_inf_i = gt_mask_inf[i_instance].numpy() + gt_points_i = gt_points[i_instance].numpy() + gt_normal_i = gt_normal[i_instance].numpy() + save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(gt_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), gt_mask_i * 255) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(gt_depth_i, gt_mask_i), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal.png')), cv2.cvtColor(colorize_normal(gt_normal_i), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask_inf.png')), gt_mask_inf_i * 255) + with save_dir.joinpath(f'{idx:04d}/info.json').open('w') as f: + json.dump(info[i_instance], f) + + # Reset seed to avoid training on the same data when resuming training + if seed is not None: + set_seed(seed + initial_step, device_specific=True) + + # Training loop + for i_step in range(initial_step, num_iterations): + + i_accumulate, weight_accumulate = 0, 0 + while i_accumulate < gradient_accumulation_steps: + # Load batch + batch = train_data_pipe.get() + image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics, label_type, is_metric = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_fin'], batch['depth_mask_inf'], batch['intrinsics'], batch['label_type'], batch['is_metric'] + image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_mask_fin.to(device), gt_mask_inf.to(device), gt_intrinsics.to(device) + current_batch_size = image.shape[0] + if all(label == 'invalid' for label in label_type): + continue # NOTE: Skip all-invalid batches to avoid messing up the optimizer. + + gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics) + gt_focal = 1 / (1 / gt_intrinsics[..., 0, 0] ** 2 + 1 / gt_intrinsics[..., 1, 1] ** 2) ** 0.5 + + with accelerator.accumulate(model): + # Forward + if i_step <= config.get('low_resolution_training_steps', 0): + num_tokens = config['model']['num_tokens_range'][0] + else: + num_tokens = accelerate.utils.broadcast_object_list([random.randint(*config['model']['num_tokens_range'])])[0] + with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16, enabled=enable_mixed_precision): + output = model(image, num_tokens=num_tokens) + pred_points, pred_mask, pred_metric_scale = output['points'], output['mask'], output.get('metric_scale', None) + + # Compute loss (per instance) + loss_list, weight_list = [], [] + for i in range(current_batch_size): + gt_metric_scale = None + loss_dict, weight_dict, misc_dict = {}, {}, {} + misc_dict['monitoring'] = monitoring(pred_points[i]) + for k, v in config['loss'][label_type[i]].items(): + weight_dict[k] = v['weight'] + if v['function'] == 'affine_invariant_global_loss': + loss_dict[k], misc_dict[k], gt_metric_scale = affine_invariant_global_loss(pred_points[i], gt_points[i], gt_mask[i], **v['params']) + elif v['function'] == 'affine_invariant_local_loss': + loss_dict[k], misc_dict[k] = affine_invariant_local_loss(pred_points[i], gt_points[i], gt_mask[i], gt_focal[i], gt_metric_scale, **v['params']) + elif v['function'] == 'normal_loss': + loss_dict[k], misc_dict[k] = normal_loss(pred_points[i], gt_points[i], gt_mask[i]) + elif v['function'] == 'edge_loss': + loss_dict[k], misc_dict[k] = edge_loss(pred_points[i], gt_points[i], gt_mask[i]) + elif v['function'] == 'mask_bce_loss': + loss_dict[k], misc_dict[k] = mask_bce_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i]) + elif v['function'] == 'mask_l2_loss': + loss_dict[k], misc_dict[k] = mask_l2_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i]) + else: + raise ValueError(f'Undefined loss function: {v["function"]}') + weight_dict = {'.'.join(k): v for k, v in flatten_nested_dict(weight_dict).items()} + loss_dict = {'.'.join(k): v for k, v in flatten_nested_dict(loss_dict).items()} + loss_ = sum([weight_dict[k] * loss_dict[k] for k in loss_dict], start=torch.tensor(0.0, device=device)) + loss_list.append(loss_) + + if torch.isnan(loss_).item(): + pbar.write(f'NaN loss in process {accelerator.process_index}') + pbar.write(str(loss_dict)) + + misc_dict = {'.'.join(k): v for k, v in flatten_nested_dict(misc_dict).items()} + records.append({ + **{k: v.item() for k, v in loss_dict.items()}, + **misc_dict, + }) + + loss = sum(loss_list) / len(loss_list) + + # Backward & update + accelerator.backward(loss) + if accelerator.sync_gradients: + if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None): + if accelerator.is_main_process: + pbar.write(f'NaN gradients, skip update') + optimizer.zero_grad() + continue + accelerator.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + optimizer.zero_grad() + + i_accumulate += 1 + + lr_scheduler.step() + + # EMA update + if enable_ema and accelerator.is_main_process and accelerator.sync_gradients: + ema_model.update_parameters(model) + + # Log metrics + if i_step == initial_step or i_step % log_every == 0: + records = [key_average(records)] + records = accelerator.gather_for_metrics(records, use_gather_object=True) + if accelerator.is_main_process: + records = key_average(records) + if enable_mlflow: + try: + mlflow.log_metrics(records, step=i_step) + except Exception as e: + print(f'Error while logging metrics to mlflow: {e}') + records = [] + + # Save model weight checkpoint + if accelerator.is_main_process and (i_step % save_every == 0): + # NOTE: Writing checkpoint is done in a separate thread to avoid blocking the main process + pbar.write(f'Save checkpoint: {i_step:08d}') + Path(workspace, 'checkpoint').mkdir(parents=True, exist_ok=True) + + # Model checkpoint + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'model': accelerator.unwrap_model(model).state_dict(), + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}.pt'), checkpoint_bytes + ) + + # Optimizer checkpoint + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'step': i_step, + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt'), checkpoint_bytes + ) + + # EMA model checkpoint + if enable_ema: + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'model': ema_model.module.state_dict(), + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt'), checkpoint_bytes + ) + + # Latest checkpoint + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'step': i_step, + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', 'latest.pt'), checkpoint_bytes + ) + + # Visualize + if vis_every > 0 and accelerator.is_main_process and (i_step == initial_step or i_step % vis_every == 0): + unwrapped_model = accelerator.unwrap_model(model) + save_dir = Path(workspace).joinpath(f'vis/step_{i_step:08d}') + save_dir.mkdir(parents=True, exist_ok=True) + with torch.inference_mode(): + for i_batch, batch in enumerate(tqdm(batches_for_vis, desc=f'Visualize: {i_step:08d}', leave=False)): + image, gt_depth, gt_mask, gt_intrinsics = batch['image'], batch['depth'], batch['depth_mask'], batch['intrinsics'] + image, gt_depth, gt_mask, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_intrinsics.to(device) + + output = unwrapped_model.infer(image) + pred_points, pred_depth, pred_mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy() + image = image.cpu().numpy() + + for i_instance in range(image.shape[0]): + idx = i_batch * batch_size_forward + i_instance + image_i = (image[i_instance].transpose(1, 2, 0) * 255).astype(np.uint8) + pred_points_i = pred_points[i_instance] + pred_mask_i = pred_mask[i_instance] + pred_depth_i = pred_depth[i_instance] + save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(pred_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), pred_mask_i * 255) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(pred_depth_i, pred_mask_i), cv2.COLOR_RGB2BGR)) + + pbar.set_postfix({'loss': loss.item()}, refresh=False) + pbar.update(1) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/moge/scripts/vis_data.py b/moge/scripts/vis_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb21766a67e4370578acbdf7bd17d1feb46b937 --- /dev/null +++ b/moge/scripts/vis_data.py @@ -0,0 +1,84 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +import sys +from pathlib import Path +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) + +import click + + +@click.command() +@click.argument('folder_or_path', type=click.Path(exists=True)) +@click.option('--output', '-o', 'output_folder', type=click.Path(), help='Path to output folder') +@click.option('--max_depth', '-m', type=float, default=float('inf'), help='max depth') +@click.option('--fov', type=float, default=None, help='field of view in degrees') +@click.option('--show', 'show', is_flag=True, help='show point cloud') +@click.option('--depth', 'depth_filename', type=str, default='depth.png', help='depth image file name') +@click.option('--ply', 'save_ply', is_flag=True, help='save point cloud as PLY file') +@click.option('--depth_vis', 'save_depth_vis', is_flag=True, help='save depth image') +@click.option('--inf', 'inf_mask', is_flag=True, help='use infinity mask') +@click.option('--version', 'version', type=str, default='v3', help='version of rgbd data') +def main( + folder_or_path: str, + output_folder: str, + max_depth: float, + fov: float, + depth_filename: str, + show: bool, + save_ply: bool, + save_depth_vis: bool, + inf_mask: bool, + version: str +): + # Lazy import + import cv2 + import numpy as np + import utils3d + from tqdm import tqdm + import trimesh + + from moge.utils.io import read_image, read_depth, read_meta + from moge.utils.vis import colorize_depth, colorize_normal + + filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json')) + + for filepath in tqdm(filepaths): + image = read_image(Path(filepath, 'image.jpg')) + depth, unit = read_depth(Path(filepath, depth_filename)) + meta = read_meta(Path(filepath,'meta.json')) + depth_mask = np.isfinite(depth) + depth_mask_inf = (depth == np.inf) + intrinsics = np.array(meta['intrinsics']) + + extrinsics = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float) # OpenGL's identity camera + verts = utils3d.numpy.unproject_cv(utils3d.numpy.image_uv(*image.shape[:2]), depth, extrinsics=extrinsics, intrinsics=intrinsics) + + depth_mask_ply = depth_mask & (depth < depth[depth_mask].min() * max_depth) + point_cloud = trimesh.PointCloud(verts[depth_mask_ply], image[depth_mask_ply] / 255) + + if show: + point_cloud.show() + + if output_folder is None: + output_path = filepath + else: + output_path = Path(output_folder, filepath.name) + output_path.mkdir(exist_ok=True, parents=True) + + if inf_mask: + depth = np.where(depth_mask_inf, np.inf, depth) + depth_mask = depth_mask | depth_mask_inf + + if save_depth_vis: + p = output_path.joinpath('depth_vis.png') + cv2.imwrite(str(p), cv2.cvtColor(colorize_depth(depth, depth_mask), cv2.COLOR_RGB2BGR)) + print(f"{p}") + + if save_ply: + p = output_path.joinpath('pointcloud.ply') + point_cloud.export(p) + print(f"{p}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/moge/test/__init__.py b/moge/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/moge/test/baseline.py b/moge/test/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..05980aaf96870304534fcec6532225e870351a66 --- /dev/null +++ b/moge/test/baseline.py @@ -0,0 +1,43 @@ +from typing import * + +import click +import torch + + +class MGEBaselineInterface: + """ + Abstract class for model wrapper to uniformize the interface of loading and inference across different models. + """ + device: torch.device + + @click.command() + @staticmethod + def load(*args, **kwargs) -> "MGEBaselineInterface": + """ + Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()` + """ + raise NotImplementedError(f"{type(self).__name__} has not implemented the load method.") + + def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + """ + ### Parameters + `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1] + `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional. + + ### Returns + A dictionary containing: + - `points_*`. point map output in OpenCV identity camera space. + Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`. + - `depth_*`. depth map output + Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`. + - `disparity_affine_invariant`. affine disparity map output + """ + raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.") + + def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + """ + If the model has a special evaluation mode, override this method to provide the evaluation mode inference. + + By default, this method simply calls `infer()`. + """ + return self.infer(image, intrinsics) \ No newline at end of file diff --git a/moge/test/dataloader.py b/moge/test/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..76679829afdf385938b604fa8bb5ef07b2560e7b --- /dev/null +++ b/moge/test/dataloader.py @@ -0,0 +1,221 @@ +import os +from typing import * +from pathlib import Path +import math + +import numpy as np +import torch +from PIL import Image +import cv2 +import utils3d + +from ..utils import pipeline +from ..utils.geometry_numpy import focal_to_fov_numpy, mask_aware_nearest_resize_numpy, norm3d +from ..utils.io import * +from ..utils.tools import timeit + + +class EvalDataLoaderPipeline: + + def __init__( + self, + path: str, + width: int, + height: int, + split: int = '.index.txt', + drop_max_depth: float = 1000., + num_load_workers: int = 4, + num_process_workers: int = 8, + include_segmentation: bool = False, + include_normal: bool = False, + depth_to_normal: bool = False, + max_segments: int = 100, + min_seg_area: int = 1000, + depth_unit: str = None, + has_sharp_boundary = False, + subset: int = None, + ): + filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines() + filenames = filenames[::subset] + self.width = width + self.height = height + self.drop_max_depth = drop_max_depth + self.path = Path(path) + self.filenames = filenames + self.include_segmentation = include_segmentation + self.include_normal = include_normal + self.max_segments = max_segments + self.min_seg_area = min_seg_area + self.depth_to_normal = depth_to_normal + self.depth_unit = depth_unit + self.has_sharp_boundary = has_sharp_boundary + + self.rng = np.random.default_rng(seed=0) + + self.pipeline = pipeline.Sequential([ + self._generator, + pipeline.Parallel([self._load_instance] * num_load_workers), + pipeline.Parallel([self._process_instance] * num_process_workers), + pipeline.Buffer(4) + ]) + + def __len__(self): + return math.ceil(len(self.filenames)) + + def _generator(self): + for idx in range(len(self)): + yield idx + + def _load_instance(self, idx): + if idx >= len(self.filenames): + return None + + path = self.path.joinpath(self.filenames[idx]) + + instance = { + 'filename': self.filenames[idx], + 'width': self.width, + 'height': self.height, + } + instance['image'] = read_image(Path(path, 'image.jpg')) + + depth, _ = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead + instance.update({ + 'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1), + 'depth_mask': np.isfinite(depth), + 'depth_mask_inf': np.isinf(depth), + }) + + if self.include_segmentation: + segmentation_mask, segmentation_labels = read_segmentation(Path(path,'segmentation.png')) + instance.update({ + 'segmentation_mask': segmentation_mask, + 'segmentation_labels': segmentation_labels, + }) + + meta = read_meta(Path(path, 'meta.json')) + instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32) + + return instance + + def _process_instance(self, instance: dict): + if instance is None: + return None + + image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics'] + segmentation_mask, segmentation_labels = instance.get('segmentation_mask', None), instance.get('segmentation_labels', None) + + raw_height, raw_width = image.shape[:2] + raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) + raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height + tgt_width, tgt_height = instance['width'], instance['height'] + tgt_aspect = tgt_width / tgt_height + + # set expected target view field + tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect) + tgt_vertical = tgt_horizontal / tgt_aspect + + # set target view direction + cu, cv = 0.5, 0.5 + direction = utils3d.numpy.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0] + R = utils3d.numpy.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32)) + + # restrict target view field within the raw view + corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32) + corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane + corners = corners[:, :2] / corners[:, 2:3] + + warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) + for i in range(4): + intersection, _ = utils3d.numpy.ray_intersection( + np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]), + corners[i - 1], corners[i] - corners[i - 1], + ) + warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min()) + tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical) + + # get target view intrinsics + fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical + tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32) + + # do homogeneous transformation with the rotation and intrinsics + # 4.1 The image and depth is resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling + tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes) + rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h) + image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS)) + + depth, depth_mask = mask_aware_nearest_resize_numpy(depth, depth_mask, (rescaled_w, rescaled_h)) + distance = norm3d(utils3d.numpy.depth_to_points(depth, intrinsics=intrinsics)) + segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None + + # 4.2 calculate homography warping + transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics) + uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height) + pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T + uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12) + pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32) + + tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) + tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) + tgt_ray_length = utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics) + tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5 + tgt_depth = tgt_distance / (tgt_ray_length + 1e-12) + tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None + + # drop depth greater than drop_max_depth + max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth + tgt_depth_mask &= tgt_depth <= max_depth + tgt_depth = np.nan_to_num(tgt_depth, nan=0.0) + + if self.depth_unit is not None: + tgt_depth *= self.depth_unit + + if not np.any(tgt_depth_mask): + # always make sure that mask is not empty, otherwise the loss calculation will crash + tgt_depth_mask = np.ones_like(tgt_depth_mask) + tgt_depth = np.ones_like(tgt_depth) + instance['label_type'] = 'invalid' + + tgt_pts = utils3d.numpy.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics) + + # Process segmentation labels + if self.include_segmentation and segmentation_mask is not None: + for k in ['undefined', 'unannotated', 'background', 'sky']: + if k in segmentation_labels: + del segmentation_labels[k] + seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True))) + sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True) + segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area} + + instance.update({ + 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1), + 'depth': torch.from_numpy(tgt_depth).float(), + 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(), + 'intrinsics': torch.from_numpy(tgt_intrinsics).float(), + 'points': torch.from_numpy(tgt_pts).float(), + 'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None, + 'segmentation_labels': segmentation_labels, + 'is_metric': self.depth_unit is not None, + 'has_sharp_boundary': self.has_sharp_boundary, + }) + + instance = {k: v for k, v in instance.items() if v is not None} + + return instance + + def start(self): + self.pipeline.start() + + def stop(self): + self.pipeline.stop() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stop() + + def get(self): + return self.pipeline.get() \ No newline at end of file diff --git a/moge/test/metrics.py b/moge/test/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..904064f2a30d05dca3a53db7ecc076a0c2aaa0ad --- /dev/null +++ b/moge/test/metrics.py @@ -0,0 +1,343 @@ +from typing import * +from numbers import Number + +import torch +import torch.nn.functional as F +import numpy as np +import utils3d + +from ..utils.geometry_torch import ( + weighted_mean, + mask_aware_nearest_resize, + intrinsics_to_fov +) +from ..utils.alignment import ( + align_points_scale_z_shift, + align_points_scale_xyz_shift, + align_points_xyz_shift, + align_affine_lstsq, + align_depth_scale, + align_depth_affine, + align_points_scale, +) +from ..utils.tools import key_average, timeit + + +def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + rel = (torch.abs(pred - gt) / (gt + eps)).mean() + return rel.item() + + +def delta1_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + delta1 = (torch.maximum(gt / pred, pred / gt) < 1.25).float().mean() + return delta1.item() + + +def rel_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + dist_gt = torch.norm(gt, dim=-1) + dist_err = torch.norm(pred - gt, dim=-1) + rel = (dist_err / (dist_gt + eps)).mean() + return rel.item() + + +def delta1_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + dist_pred = torch.norm(pred, dim=-1) + dist_gt = torch.norm(gt, dim=-1) + dist_err = torch.norm(pred - gt, dim=-1) + + delta1 = (dist_err < 0.25 * torch.minimum(dist_gt, dist_pred)).float().mean() + return delta1.item() + + +def rel_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor): + dist_err = torch.norm(pred - gt, dim=-1) + rel = (dist_err / diameter).mean() + return rel.item() + + +def delta1_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor): + dist_err = torch.norm(pred - gt, dim=-1) + delta1 = (dist_err < 0.25 * diameter).float().mean() + return delta1.item() + + +def boundary_f1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, radius: int = 1): + neighbor_x, neight_y = torch.meshgrid( + torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device), + torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device), + indexing='xy' + ) + neighbor_mask = (neighbor_x ** 2 + neight_y ** 2) <= radius ** 2 + 1e-5 + + pred_window = utils3d.torch.sliding_window_2d(pred, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1] + gt_window = utils3d.torch.sliding_window_2d(gt, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1] + mask_window = neighbor_mask & utils3d.torch.sliding_window_2d(mask, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1] + + pred_rel = pred_window / pred[radius:-radius, radius:-radius, None, None] + gt_rel = gt_window / gt[radius:-radius, radius:-radius, None, None] + valid = mask[radius:-radius, radius:-radius, None, None] & mask_window + + f1_list = [] + w_list = t_list = torch.linspace(0.05, 0.25, 10).tolist() + + for t in t_list: + pred_label = pred_rel > 1 + t + gt_label = gt_rel > 1 + t + TP = (pred_label & gt_label & valid).float().sum() + precision = TP / (gt_label & valid).float().sum().clamp_min(1e-12) + recall = TP / (pred_label & valid).float().sum().clamp_min(1e-12) + f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-12) + f1_list.append(f1.item()) + + f1_avg = sum(w * f1 for w, f1 in zip(w_list, f1_list)) / sum(w_list) + return f1_avg + + +def compute_metrics( + pred: Dict[str, torch.Tensor], + gt: Dict[str, torch.Tensor], + vis: bool = False +) -> Tuple[Dict[str, Dict[str, Number]], Dict[str, torch.Tensor]]: + """ + A unified function to compute metrics for different types of predictions and ground truths. + + #### Supported keys in pred: + - `disparity_affine_invariant`: disparity map predicted by a depth estimator with scale and shift invariant. + - `depth_scale_invariant`: depth map predicted by a depth estimator with scale invariant. + - `depth_affine_invariant`: depth map predicted by a depth estimator with scale and shift invariant. + - `depth_metric`: depth map predicted by a depth estimator with no scale or shift. + - `points_scale_invariant`: point map predicted by a point estimator with scale invariant. + - `points_affine_invariant`: point map predicted by a point estimator with scale and xyz shift invariant. + - `points_metric`: point map predicted by a point estimator with no scale or shift. + - `intrinsics`: normalized camera intrinsics matrix. + + #### Required keys in gt: + - `depth`: depth map ground truth (in metric units if `depth_metric` is used) + - `points`: point map ground truth in camera coordinates. + - `mask`: mask indicating valid pixels in the ground truth. + - `intrinsics`: normalized ground-truth camera intrinsics matrix. + - `is_metric`: whether the depth is in metric units. + """ + metrics = {} + misc = {} + + mask = gt['depth_mask'] + gt_depth = gt['depth'] + gt_points = gt['points'] + + height, width = mask.shape[-2:] + _, lr_mask, lr_index = mask_aware_nearest_resize(None, mask, (64, 64), return_index=True) + + only_depth = not any('point' in k for k in pred) + pred_depth_aligned, pred_points_aligned = None, None + + # Metric depth + if 'depth_metric' in pred and gt['is_metric']: + pred_depth, gt_depth = pred['depth_metric'], gt['depth'] + metrics['depth_metric'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = pred_depth + + # Scale-invariant depth + if 'depth_scale_invariant' in pred: + pred_depth_scale_invariant = pred['depth_scale_invariant'] + elif 'depth_metric' in pred: + pred_depth_scale_invariant = pred['depth_metric'] + else: + pred_depth_scale_invariant = None + + if pred_depth_scale_invariant is not None: + pred_depth = pred_depth_scale_invariant + + pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask] + scale = align_depth_scale(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked) + pred_depth = pred_depth * scale + + metrics['depth_scale_invariant'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = pred_depth + + # Affine-invariant depth + if 'depth_affine_invariant' in pred: + pred_depth_affine_invariant = pred['depth_affine_invariant'] + elif 'depth_scale_invariant' in pred: + pred_depth_affine_invariant = pred['depth_scale_invariant'] + elif 'depth_metric' in pred: + pred_depth_affine_invariant = pred['depth_metric'] + else: + pred_depth_affine_invariant = None + + if pred_depth_affine_invariant is not None: + pred_depth = pred_depth_affine_invariant + + pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask] + scale, shift = align_depth_affine(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked) + pred_depth = pred_depth * scale + shift + + metrics['depth_affine_invariant'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = pred_depth + + # Affine-invariant disparity + if 'disparity_affine_invariant' in pred: + pred_disparity_affine_invariant = pred['disparity_affine_invariant'] + elif 'depth_scale_invariant' in pred: + pred_disparity_affine_invariant = 1 / pred['depth_scale_invariant'] + elif 'depth_metric' in pred: + pred_disparity_affine_invariant = 1 / pred['depth_metric'] + else: + pred_disparity_affine_invariant = None + + if pred_disparity_affine_invariant is not None: + pred_disp = pred_disparity_affine_invariant + + scale, shift = align_affine_lstsq(pred_disp[mask], 1 / gt_depth[mask]) + pred_disp = pred_disp * scale + shift + + # NOTE: The alignment is done on the disparity map could introduce extreme outliers at disparities close to 0. + # Therefore we clamp the disparities by minimum ground truth disparity. + pred_depth = 1 / pred_disp.clamp_min(1 / gt_depth[mask].max().item()) + + metrics['disparity_affine_invariant'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = 1 / pred_disp.clamp_min(1e-6) + + # Metric points + if 'points_metric' in pred and gt['is_metric']: + pred_points = pred['points_metric'] + + pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask] + shift = align_points_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1)) + pred_points = pred_points + shift + + metrics['points_metric'] = { + 'rel': rel_point(pred_points[mask], gt_points[mask]), + 'delta1': delta1_point(pred_points[mask], gt_points[mask]) + } + + if pred_points_aligned is None: + pred_points_aligned = pred['points_metric'] + + # Scale-invariant points (in camera space) + if 'points_scale_invariant' in pred: + pred_points_scale_invariant = pred['points_scale_invariant'] + elif 'points_metric' in pred: + pred_points_scale_invariant = pred['points_metric'] + else: + pred_points_scale_invariant = None + + if pred_points_scale_invariant is not None: + pred_points = pred_points_scale_invariant + + pred_points_lr_masked, gt_points_lr_masked = pred_points_scale_invariant[lr_index][lr_mask], gt_points[lr_index][lr_mask] + scale = align_points_scale(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1)) + pred_points = pred_points * scale + + metrics['points_scale_invariant'] = { + 'rel': rel_point(pred_points[mask], gt_points[mask]), + 'delta1': delta1_point(pred_points[mask], gt_points[mask]) + } + + if vis and pred_points_aligned is None: + pred_points_aligned = pred['points_scale_invariant'] * scale + + # Affine-invariant points + if 'points_affine_invariant' in pred: + pred_points_affine_invariant = pred['points_affine_invariant'] + elif 'points_scale_invariant' in pred: + pred_points_affine_invariant = pred['points_scale_invariant'] + elif 'points_metric' in pred: + pred_points_affine_invariant = pred['points_metric'] + else: + pred_points_affine_invariant = None + + if pred_points_affine_invariant is not None: + pred_points = pred_points_affine_invariant + + pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask] + scale, shift = align_points_scale_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1)) + pred_points = pred_points * scale + shift + + metrics['points_affine_invariant'] = { + 'rel': rel_point(pred_points[mask], gt_points[mask]), + 'delta1': delta1_point(pred_points[mask], gt_points[mask]) + } + + if vis and pred_points_aligned is None: + pred_points_aligned = pred['points_affine_invariant'] * scale + shift + + # Local points + if 'segmentation_mask' in gt and 'points' in gt and any('points' in k for k in pred.keys()): + pred_points = next(pred[k] for k in pred.keys() if 'points' in k) + gt_points = gt['points'] + segmentation_mask = gt['segmentation_mask'] + segmentation_labels = gt['segmentation_labels'] + segmentation_mask_lr = segmentation_mask[lr_index] + local_points_metrics = [] + for _, seg_id in segmentation_labels.items(): + valid_mask = (segmentation_mask == seg_id) & mask + + pred_points_masked = pred_points[valid_mask] + gt_points_masked = gt_points[valid_mask] + + valid_mask_lr = (segmentation_mask_lr == seg_id) & lr_mask + if valid_mask_lr.sum().item() < 10: + continue + pred_points_masked_lr = pred_points[lr_index][valid_mask_lr] + gt_points_masked_lr = gt_points[lr_index][valid_mask_lr] + diameter = (gt_points_masked.max(dim=0).values - gt_points_masked.min(dim=0).values).max() + scale, shift = align_points_scale_xyz_shift(pred_points_masked_lr, gt_points_masked_lr, 1 / diameter.expand(gt_points_masked_lr.shape[0])) + pred_points_masked = pred_points_masked * scale + shift + + local_points_metrics.append({ + 'rel': rel_point_local(pred_points_masked, gt_points_masked, diameter), + 'delta1': delta1_point_local(pred_points_masked, gt_points_masked, diameter), + }) + + metrics['local_points'] = key_average(local_points_metrics) + + # FOV. NOTE: If there is no random augmentation applied to the input images, all GT FOV are generallly the same. + # Fair evaluation of FOV requires random augmentation. + if 'intrinsics' in pred and 'intrinsics' in gt: + pred_intrinsics = pred['intrinsics'] + gt_intrinsics = gt['intrinsics'] + pred_fov_x, pred_fov_y = intrinsics_to_fov(pred_intrinsics) + gt_fov_x, gt_fov_y = intrinsics_to_fov(gt_intrinsics) + metrics['fov_x'] = { + 'mae': torch.rad2deg(pred_fov_x - gt_fov_x).abs().mean().item(), + 'deviation': torch.rad2deg(pred_fov_x - gt_fov_x).item(), + } + + # Boundary F1 + if pred_depth_aligned is not None and gt['has_sharp_boundary']: + metrics['boundary'] = { + 'radius1_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=1), + 'radius2_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=2), + 'radius3_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=3), + } + + if vis: + if pred_points_aligned is not None: + misc['pred_points'] = pred_points_aligned + if only_depth: + misc['pred_points'] = utils3d.torch.depth_to_points(pred_depth_aligned, intrinsics=gt['intrinsics']) + if pred_depth_aligned is not None: + misc['pred_depth'] = pred_depth_aligned + + return metrics, misc \ No newline at end of file diff --git a/moge/train/__init__.py b/moge/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/moge/train/dataloader.py b/moge/train/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bfc280844dac602e89bee747e247946dbc6f67 --- /dev/null +++ b/moge/train/dataloader.py @@ -0,0 +1,338 @@ +import os +from pathlib import Path +import json +import time +import random +from typing import * +import traceback +import itertools +from numbers import Number +import io + +import numpy as np +import cv2 +from PIL import Image +import torch +import torchvision.transforms.v2.functional as TF +import utils3d +from tqdm import tqdm + +from ..utils import pipeline +from ..utils.io import * +from ..utils.geometry_numpy import mask_aware_nearest_resize_numpy, harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy, depth_of_field + + +class TrainDataLoaderPipeline: + def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8): + self.config = config + + self.batch_size = batch_size + self.clamp_max_depth = config['clamp_max_depth'] + self.fov_range_absolute = config.get('fov_range_absolute', 0.0) + self.fov_range_relative = config.get('fov_range_relative', 0.0) + self.center_augmentation = config.get('center_augmentation', 0.0) + self.image_augmentation = config.get('image_augmentation', []) + self.depth_interpolation = config.get('depth_interpolation', 'bilinear') + + if 'image_sizes' in config: + self.image_size_strategy = 'fixed' + self.image_sizes = config['image_sizes'] + elif 'aspect_ratio_range' in config and 'area_range' in config: + self.image_size_strategy = 'aspect_area' + self.aspect_ratio_range = config['aspect_ratio_range'] + self.area_range = config['area_range'] + else: + raise ValueError('Invalid image size configuration') + + # Load datasets + self.datasets = {} + for dataset in tqdm(config['datasets'], desc='Loading datasets'): + name = dataset['name'] + content = Path(dataset['path'], dataset.get('index', '.index.txt')).joinpath().read_text() + filenames = content.splitlines() + self.datasets[name] = { + **dataset, + 'path': dataset['path'], + 'filenames': filenames, + } + self.dataset_names = [dataset['name'] for dataset in config['datasets']] + self.dataset_weights = [dataset['weight'] for dataset in config['datasets']] + + # Build pipeline + self.pipeline = pipeline.Sequential([ + self._sample_batch, + pipeline.Unbatch(), + pipeline.Parallel([self._load_instance] * num_load_workers), + pipeline.Parallel([self._process_instance] * num_process_workers), + pipeline.Batch(self.batch_size), + self._collate_batch, + pipeline.Buffer(buffer_size), + ]) + + self.invalid_instance = { + 'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32), + 'image': np.zeros((256, 256, 3), dtype=np.uint8), + 'depth': np.ones((256, 256), dtype=np.float32), + 'depth_mask': np.ones((256, 256), dtype=bool), + 'depth_mask_inf': np.zeros((256, 256), dtype=bool), + 'label_type': 'invalid', + } + + def _sample_batch(self): + batch_id = 0 + last_area = None + while True: + # Depending on the sample strategy, choose a dataset and a filename + batch_id += 1 + batch = [] + + # Sample instances + for _ in range(self.batch_size): + dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0] + filename = random.choice(self.datasets[dataset_name]['filenames']) + + path = Path(self.datasets[dataset_name]['path'], filename) + + instance = { + 'batch_id': batch_id, + 'seed': random.randint(0, 2 ** 32 - 1), + 'dataset': dataset_name, + 'filename': filename, + 'path': path, + 'label_type': self.datasets[dataset_name]['label_type'], + } + batch.append(instance) + + # Decide the image size for this batch + if self.image_size_strategy == 'fixed': + width, height = random.choice(self.config['image_sizes']) + elif self.image_size_strategy == 'aspect_area': + area = random.uniform(*self.area_range) + aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch] + aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges)) + aspect_ratio = random.uniform(*aspect_ratio_range) + width, height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5) + else: + raise ValueError('Invalid image size strategy') + + for instance in batch: + instance['width'], instance['height'] = width, height + + yield batch + + def _load_instance(self, instance: dict): + try: + image = read_image(Path(instance['path'], 'image.jpg')) + depth, _ = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png'))) + + meta = read_meta(Path(instance['path'], 'meta.json')) + intrinsics = np.array(meta['intrinsics'], dtype=np.float32) + depth_mask = np.isfinite(depth) + depth_mask_inf = np.isinf(depth) + depth = np.nan_to_num(depth, nan=1, posinf=1, neginf=1) + data = { + 'image': image, + 'depth': depth, + 'depth_mask': depth_mask, + 'depth_mask_inf': depth_mask_inf, + 'intrinsics': intrinsics + } + instance.update({ + **data, + }) + except Exception as e: + print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e) + instance.update(self.invalid_instance) + return instance + + def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]): + image, depth, depth_mask, depth_mask_inf, intrinsics, label_type = instance['image'], instance['depth'], instance['depth_mask'], instance['depth_mask_inf'], instance['intrinsics'], instance['label_type'] + depth_unit = self.datasets[instance['dataset']].get('depth_unit', None) + + raw_height, raw_width = image.shape[:2] + raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) + raw_fov_x, raw_fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height + tgt_width, tgt_height = instance['width'], instance['height'] + tgt_aspect = tgt_width / tgt_height + + rng = np.random.default_rng(instance['seed']) + + # 1. set target fov + center_augmentation = self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation) + fov_range_absolute_min, fov_range_absolute_max = self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute) + fov_range_relative_min, fov_range_relative_max = self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative) + tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, fov_range_relative_min * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect)) + tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, fov_range_relative_max * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect)) + tgt_fov_x_min, tgt_fov_x_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max) + tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max) + tgt_fov_y = utils3d.focal_to_fov(utils3d.numpy.fov_to_focal(tgt_fov_x) * tgt_aspect) + + # 2. set target image center (principal point) and the corresponding z-direction in raw camera space + center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x) + center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y) + cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2) + direction = utils3d.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0] + + # 3. obtain the rotation matrix for homography warping + R = utils3d.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32)) + + # 4. shrink the target view to fit into the warped image + corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32) + corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane + corners = corners[:, :2] / corners[:, 2:3] + tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2 + warp_horizontal, warp_vertical = float('inf'), float('inf') + for i in range(4): + intersection, _ = utils3d.numpy.ray_intersection( + np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]), + corners[i - 1], corners[i] - corners[i - 1], + ) + warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min()) + tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical) + + # 5. obtain the target intrinsics + fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical + tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32) + + # 6. do homogeneous transformation + # 6.1 The image and depth are resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling + tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes) + rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h) + image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS)) + + edge_mask = depth_occlusion_edge_numpy(depth, mask=depth_mask, thickness=2, tol=0.01) + _, depth_mask_nearest, resize_index = mask_aware_nearest_resize_numpy(None, depth_mask, (rescaled_w, rescaled_h), return_index=True) + depth_nearest = depth[resize_index] + distance_nearest = norm3d(utils3d.numpy.depth_to_points(depth_nearest, intrinsics=intrinsics)) + edge_mask = edge_mask[resize_index] + + if self.depth_interpolation == 'bilinear': + depth_mask_bilinear = cv2.resize(depth_mask.astype(np.float32), (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR) + depth_bilinear = 1 / cv2.resize(1 / depth, (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR) + distance_bilinear = norm3d(utils3d.numpy.depth_to_points(depth_bilinear, intrinsics=intrinsics)) + + depth_mask_inf = cv2.resize(depth_mask_inf.astype(np.uint8), (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) > 0 + + # 6.2 calculate homography warping + transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics) + uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height) + pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T + uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12) + pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32) + + tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LANCZOS4) + tgt_ray_length = norm3d(utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics)) + tgt_depth_mask_nearest = cv2.remap(depth_mask_nearest.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + tgt_depth_nearest = cv2.remap(distance_nearest, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) / tgt_ray_length + tgt_edge_mask = cv2.remap(edge_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + if self.depth_interpolation == 'bilinear': + tgt_depth_mask_bilinear = cv2.remap(depth_mask_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) + tgt_depth_bilinear = cv2.remap(distance_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) / tgt_ray_length + tgt_depth = np.where((tgt_depth_mask_bilinear == 1) & ~tgt_edge_mask, tgt_depth_bilinear, tgt_depth_nearest) + else: + tgt_depth = tgt_depth_nearest + tgt_depth_mask = tgt_depth_mask_nearest + + tgt_depth_mask_inf = cv2.remap(depth_mask_inf.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + + # always make sure that mask is not empty + if tgt_depth_mask.sum() / tgt_depth_mask.size < 0.001: + tgt_depth_mask = np.ones_like(tgt_depth_mask) + tgt_depth = np.ones_like(tgt_depth) + instance['label_type'] = 'invalid' + + # Flip augmentation + if rng.choice([True, False]): + tgt_image = np.flip(tgt_image, axis=1).copy() + tgt_depth = np.flip(tgt_depth, axis=1).copy() + tgt_depth_mask = np.flip(tgt_depth_mask, axis=1).copy() + tgt_depth_mask_inf = np.flip(tgt_depth_mask_inf, axis=1).copy() + + # Color augmentation + image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation) + if 'jittering' in image_augmentation: + tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1) + tgt_image = TF.adjust_brightness(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = TF.adjust_contrast(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = TF.adjust_saturation(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = TF.adjust_hue(tgt_image, rng.uniform(-0.1, 0.1)) + tgt_image = TF.adjust_gamma(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = tgt_image.permute(1, 2, 0).numpy() + if 'dof' in image_augmentation: + if rng.uniform() < 0.5: + dof_strength = rng.integers(12) + tgt_disp = np.where(tgt_depth_mask_inf, 0, 1 / tgt_depth) + disp_min, disp_max = tgt_disp[tgt_depth_mask].min(), tgt_disp[tgt_depth_mask].max() + tgt_disp = cv2.inpaint(tgt_disp, (~tgt_depth_mask & ~tgt_depth_mask_inf).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(disp_min, disp_max) + dof_focus = rng.uniform(disp_min, disp_max) + tgt_image = depth_of_field(tgt_image, tgt_disp, dof_focus, dof_strength) + if 'shot_noise' in image_augmentation: + if rng.uniform() < 0.5: + k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255 + tgt_image = (rng.poisson(tgt_image * k) / k).clip(0, 255).astype(np.uint8) + if 'jpeg_loss' in image_augmentation: + if rng.uniform() < 0.5: + tgt_image = cv2.imdecode(cv2.imencode('.jpg', tgt_image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR) + if 'blurring' in image_augmentation: + if rng.uniform() < 0.5: + ratio = rng.uniform(0.25, 1) + tgt_image = cv2.resize(cv2.resize(tgt_image, (int(tgt_width * ratio), int(tgt_height * ratio)), interpolation=cv2.INTER_AREA), (tgt_width, tgt_height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])) + + # convert depth to metric if necessary + if depth_unit is not None: + tgt_depth *= depth_unit + instance['is_metric'] = True + else: + instance['is_metric'] = False + + # clamp depth maximum values + max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.clamp_max_depth + tgt_depth = np.clip(tgt_depth, 0, max_depth) + tgt_depth = np.nan_to_num(tgt_depth, nan=1.0) + + if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known": + tgt_depth_mask_fin = tgt_depth_mask + else: + tgt_depth_mask_fin = ~tgt_depth_mask_inf + + instance.update({ + 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1), + 'depth': torch.from_numpy(tgt_depth).float(), + 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(), + 'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(), + 'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(), + 'intrinsics': torch.from_numpy(tgt_intrinsics).float(), + }) + + return instance + + def _collate_batch(self, instances: List[Dict[str, Any]]): + batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask', 'depth_mask_fin', 'depth_mask_inf', 'intrinsics']} + batch = { + 'label_type': [instance['label_type'] for instance in instances], + 'is_metric': [instance['is_metric'] for instance in instances], + 'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances], + **batch, + } + return batch + + def get(self) -> Dict[str, Union[torch.Tensor, str]]: + return self.pipeline.get() + + def start(self): + self.pipeline.start() + + def stop(self): + self.pipeline.stop() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.pipeline.terminate() + self.pipeline.join() + return False + + diff --git a/moge/train/losses.py b/moge/train/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2b251b230f4cc86d8358f613acf483badfb49e14 --- /dev/null +++ b/moge/train/losses.py @@ -0,0 +1,270 @@ +from typing import * +import math + +import torch +import torch.nn.functional as F +import utils3d + +from ..utils.geometry_torch import ( + weighted_mean, + harmonic_mean, + geometric_mean, + mask_aware_nearest_resize, + normalized_view_plane_uv, + angle_diff_vec3 +) +from ..utils.alignment import ( + align_points_scale_z_shift, + align_points_scale, + align_points_scale_xyz_shift, + align_points_z_shift, +) + + +def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor: + if beta == 0: + return err + else: + return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta) + + +def affine_invariant_global_loss( + pred_points: torch.Tensor, + gt_points: torch.Tensor, + mask: torch.Tensor, + align_resolution: int = 64, + beta: float = 0.0, + trunc: float = 1.0, + sparsity_aware: bool = False +): + device = pred_points.device + + # Align + (pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution)) + scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc) + valid = scale > 0 + scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0) + + pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :] + + # Compute loss + weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5) + weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values + loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1)) + + if sparsity_aware: + # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1. + sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1)) + loss = loss / (sparsity + 1e-7) + + err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2] + + # Record any scalar metric + misc = { + 'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(), + 'delta': weighted_mean((err < 1).float(), mask).item() + } + + return loss, misc, scale.detach() + + +def monitoring(points: torch.Tensor): + return { + 'std': points.std().item(), + } + + +def compute_anchor_sampling_weight( + points: torch.Tensor, + mask: torch.Tensor, + radius_2d: torch.Tensor, + radius_3d: torch.Tensor, + num_test: int = 64 +) -> torch.Tensor: + # Importance sampling to balance the sampled probability of fine strutures. + # NOTE: MoGe-1 uses uniform random sampling instead of importance sampling. + # This is an incremental trick introduced later than the publication of MoGe-1 paper. + + height, width = points.shape[-3:-1] + + pixel_i, pixel_j = torch.meshgrid( + torch.arange(height, device=points.device), + torch.arange(width, device=points.device), + indexing='ij' + ) + + test_delta_i = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test] + test_delta_j = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test] + test_i, test_j = pixel_i[..., None] + test_delta_i, pixel_j[..., None] + test_delta_j # [height, width, num_test] + test_mask = (test_i >= 0) & (test_i < height) & (test_j >= 0) & (test_j < width) # [height, width, num_test] + test_i, test_j = test_i.clamp(0, height - 1), test_j.clamp(0, width - 1) # [height, width, num_test] + test_mask = test_mask & mask[..., test_i, test_j] # [..., height, width, num_test] + test_points = points[..., test_i, test_j, :] # [..., height, width, num_test, 3] + test_dist = (test_points - points[..., None, :]).norm(dim=-1) # [..., height, width, num_test] + + weight = 1 / ((test_dist <= radius_3d[..., None]) & test_mask).float().sum(dim=-1).clamp_min(1) + weight = torch.where(mask, weight, 0) + weight = weight / weight.sum(dim=(-2, -1), keepdim=True).add(1e-7) # [..., height, width] + return weight + + +def affine_invariant_local_loss( + pred_points: torch.Tensor, + gt_points: torch.Tensor, + gt_mask: torch.Tensor, + focal: torch.Tensor, + global_scale: torch.Tensor, + level: Literal[4, 16, 64], + align_resolution: int = 32, + num_patches: int = 16, + beta: float = 0.0, + trunc: float = 1.0, + sparsity_aware: bool = False +): + device, dtype = pred_points.device, pred_points.dtype + *batch_shape, height, width, _ = pred_points.shape + batch_size = math.prod(batch_shape) + pred_points, gt_points, gt_mask, focal, global_scale = pred_points.reshape(-1, height, width, 3), gt_points.reshape(-1, height, width, 3), gt_mask.reshape(-1, height, width), focal.reshape(-1), global_scale.reshape(-1) if global_scale is not None else None + + # Sample patch anchor points indices [num_total_patches] + radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5) + radius_3d = 0.5 / level / focal * gt_points[..., 2] + anchor_sampling_weights = compute_anchor_sampling_weight(gt_points, gt_mask, radius_2d, radius_3d, num_test=64) + where_mask = torch.where(gt_mask) + random_selection = torch.multinomial(anchor_sampling_weights[where_mask], num_patches * batch_size, replacement=True) + patch_batch_idx, patch_anchor_i, patch_anchor_j = [indices[random_selection] for indices in where_mask] # [num_total_patches] + + # Get patch indices [num_total_patches, patch_h, patch_w] + patch_i, patch_j = torch.meshgrid( + torch.arange(-radius_2d, radius_2d + 1, device=device), + torch.arange(-radius_2d, radius_2d + 1, device=device), + indexing='ij' + ) + patch_i, patch_j = patch_i + patch_anchor_i[:, None, None], patch_j + patch_anchor_j[:, None, None] + patch_mask = (patch_i >= 0) & (patch_i < height) & (patch_j >= 0) & (patch_j < width) + patch_i, patch_j = patch_i.clamp(0, height - 1), patch_j.clamp(0, width - 1) + + # Get patch mask and gt patch points + gt_patch_anchor_points = gt_points[patch_batch_idx, patch_anchor_i, patch_anchor_j] + gt_patch_radius_3d = 0.5 / level / focal[patch_batch_idx] * gt_patch_anchor_points[:, 2] + gt_patch_points = gt_points[patch_batch_idx[:, None, None], patch_i, patch_j] + gt_patch_dist = (gt_patch_points - gt_patch_anchor_points[:, None, None, :]).norm(dim=-1) + patch_mask &= gt_mask[patch_batch_idx[:, None, None], patch_i, patch_j] + patch_mask &= gt_patch_dist <= gt_patch_radius_3d[:, None, None] + + # Pick only non-empty patches + MINIMUM_POINTS_PER_PATCH = 32 + nonempty = torch.where(patch_mask.sum(dim=(-2, -1)) >= MINIMUM_POINTS_PER_PATCH) + num_nonempty_patches = nonempty[0].shape[0] + if num_nonempty_patches == 0: + return torch.tensor(0.0, dtype=dtype, device=device), {} + + # Finalize all patch variables + patch_batch_idx, patch_i, patch_j = patch_batch_idx[nonempty], patch_i[nonempty], patch_j[nonempty] + patch_mask = patch_mask[nonempty] # [num_nonempty_patches, patch_h, patch_w] + gt_patch_points = gt_patch_points[nonempty] # [num_nonempty_patches, patch_h, patch_w, 3] + gt_patch_radius_3d = gt_patch_radius_3d[nonempty] # [num_nonempty_patches] + gt_patch_anchor_points = gt_patch_anchor_points[nonempty] # [num_nonempty_patches, 3] + pred_patch_points = pred_points[patch_batch_idx[:, None, None], patch_i, patch_j] + + # Align patch points + (pred_patch_points_lr, gt_patch_points_lr), patch_lr_mask = mask_aware_nearest_resize((pred_patch_points, gt_patch_points), mask=patch_mask, size=(align_resolution, align_resolution)) + local_scale, local_shift = align_points_scale_xyz_shift(pred_patch_points_lr.flatten(-3, -2), gt_patch_points_lr.flatten(-3, -2), patch_lr_mask.flatten(-2) / gt_patch_radius_3d[:, None].add(1e-7), trunc=trunc) + if global_scale is not None: + scale_differ = local_scale / global_scale[patch_batch_idx] + patch_valid = (scale_differ > 0.1) & (scale_differ < 10.0) & (global_scale > 0) + else: + patch_valid = local_scale > 0 + local_scale, local_shift = torch.where(patch_valid, local_scale, 0), torch.where(patch_valid[:, None], local_shift, 0) + patch_mask &= patch_valid[:, None, None] + + pred_patch_points = local_scale[:, None, None, None] * pred_patch_points + local_shift[:, None, None, :] # [num_patches_nonempty, patch_h, patch_w, 3] + + # Compute loss + gt_mean = harmonic_mean(gt_points[..., 2], gt_mask, dim=(-2, -1)) + patch_weight = patch_mask.float() / gt_patch_points[..., 2].clamp_min(0.1 * gt_mean[patch_batch_idx, None, None]) # [num_patches_nonempty, patch_h, patch_w] + loss = _smooth((pred_patch_points - gt_patch_points).abs() * patch_weight[..., None], beta=beta).mean(dim=(-3, -2, -1)) # [num_patches_nonempty] + + if sparsity_aware: + # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1. + sparsity = patch_mask.float().mean(dim=(-2, -1)) / patch_lr_mask.float().mean(dim=(-2, -1)) + loss = loss / (sparsity + 1e-7) + loss = torch.scatter_reduce(torch.zeros(batch_size, dtype=dtype, device=device), dim=0, index=patch_batch_idx, src=loss, reduce='sum') / num_patches + loss = loss.reshape(batch_shape) + + err = (pred_patch_points.detach() - gt_patch_points).norm(dim=-1) / gt_patch_radius_3d[..., None, None] + + # Record any scalar metric + misc = { + 'truncated_error': weighted_mean(err.clamp_max(1), patch_mask).item(), + 'delta': weighted_mean((err < 1).float(), patch_mask).item() + } + + return loss, misc + +def normal_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + device, dtype = points.device, points.dtype + height, width = points.shape[-3:-1] + + leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :] + upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1) + leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1) + downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1) + rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1) + + gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :] + gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1) + gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1) + gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1) + gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1) + + mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:] + mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown + mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup + mask_downxright = mask_leftdown & mask_rightup & mask_leftup + mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown + + MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3) + + loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \ + + mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \ + + mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \ + + mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) + + loss = loss.mean() / (4 * max(points.shape[-3:-1])) + + return loss, {} + + +def edge_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + device, dtype = points.device, points.dtype + height, width = points.shape[-3:-1] + + dx = points[..., :-1, :, :] - points[..., 1:, :, :] + dy = points[..., :, :-1, :] - points[..., :, 1:, :] + + gt_dx = gt_points[..., :-1, :, :] - gt_points[..., 1:, :, :] + gt_dy = gt_points[..., :, :-1, :] - gt_points[..., :, 1:, :] + + mask_dx = mask[..., :-1, :] & mask[..., 1:, :] + mask_dy = mask[..., :, :-1] & mask[..., :, 1:] + + MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(0.1), math.radians(90), math.radians(3) + + loss_dx = mask_dx * _smooth(angle_diff_vec3(dx, gt_dx).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) + loss_dy = mask_dy * _smooth(angle_diff_vec3(dy, gt_dy).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) + loss = (loss_dx.mean(dim=(-2, -1)) + loss_dy.mean(dim=(-2, -1))) / (2 * max(points.shape[-3:-1])) + + return loss, {} + + +def mask_l2_loss(pred_mask: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor: + loss = gt_mask_neg.float() * pred_mask.square() + gt_mask_pos.float() * (1 - pred_mask).square() + loss = loss.mean(dim=(-2, -1)) + return loss, {} + + +def mask_bce_loss(pred_mask_prob: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor: + loss = (gt_mask_pos | gt_mask_neg) * F.binary_cross_entropy(pred_mask_prob, gt_mask_pos.float(), reduction='none') + loss = loss.mean(dim=(-2, -1)) + return loss, {} diff --git a/moge/train/utils.py b/moge/train/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f21e00876b927991381bf2f777a68b02c5b38cc --- /dev/null +++ b/moge/train/utils.py @@ -0,0 +1,57 @@ +from typing import * +import fnmatch + +import sympy +import torch +import torch.nn as nn + + +def any_match(s: str, patterns: List[str]) -> bool: + return any(fnmatch.fnmatch(s, pat) for pat in patterns) + + +def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer: + named_param_groups = [ + { + k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', [])) + } for param_group_config in optimizer_config['params'] + ] + excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)] + assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}' + optimizer_cls = getattr(torch.optim, optimizer_config['type']) + optimizer = optimizer_cls([ + { + **param_group_config, + 'params': list(params.values()), + } for param_group_config, params in zip(optimizer_config['params'], named_param_groups) + ]) + return optimizer + + +def parse_lr_lambda(s: str) -> Callable[[int], float]: + epoch = sympy.symbols('epoch') + lr_lambda = sympy.sympify(s) + return sympy.lambdify(epoch, lr_lambda, 'math') + + +def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler: + if scheduler_config['type'] == "SequentialLR": + child_schedulers = [ + build_lr_scheduler(optimizer, child_scheduler_config) + for child_scheduler_config in scheduler_config['params']['schedulers'] + ] + return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones']) + elif scheduler_config['type'] == "LambdaLR": + lr_lambda = scheduler_config['params']['lr_lambda'] + if isinstance(lr_lambda, str): + lr_lambda = parse_lr_lambda(lr_lambda) + elif isinstance(lr_lambda, list): + lr_lambda = [parse_lr_lambda(l) for l in lr_lambda] + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lr_lambda, + ) + else: + scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type']) + scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {})) + return scheduler \ No newline at end of file diff --git a/moge/utils/__init__.py b/moge/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/moge/utils/alignment.py b/moge/utils/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6bb78766ec1a43a89a4fc931b64f70c5201e2d --- /dev/null +++ b/moge/utils/alignment.py @@ -0,0 +1,416 @@ +from typing import * +import math +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +import utils3d + + +def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min: + "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`." + shape = src.shape[:dim] + (size,) + src.shape[dim + 1:] + minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False) + minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index)) + indices = torch.full(shape, -1, dtype=torch.long, device=src.device) + indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim] + return torch.return_types.min((minimum, indices)) + + +def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs): + batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0] + n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0) + splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args) + splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()} + results = [] + for i in range(n_chunks): + chunk_args = tuple(arg[i] for arg in splited_args) + chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()} + results.append(fn(*chunk_args, **chunk_kwargs)) + + if isinstance(results[0], tuple): + return tuple(torch.cat(r, dim=0) for r in zip(*results)) + else: + return torch.cat(results, dim=0) + + +def _pad_inf(x_: torch.Tensor): + return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1) + + +def _pad_cumsum(cumsum: torch.Tensor): + return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1) + + +def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float): + return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1) + + +def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + """ + If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`. + + w_i must be >= 0. + + ### Parameters: + - `x`: tensor of shape (..., n) + - `y`: tensor of shape (..., n) + - `w`: tensor of shape (..., n) + - `trunc`: optional, float or tensor of shape (..., n) or None + + ### Returns: + - `a`: tensor of shape (...), differentiable + - `loss`: tensor of shape (...), value of loss function at `a`, detached + - `index`: tensor of shape (...), where a = y[idx] / x[idx] + """ + if trunc is None: + x, y, w = torch.broadcast_tensors(x, y, w) + sign = torch.sign(x) + x, y = x * sign, y * sign + y_div_x = y / x.clamp_min(eps) + y_div_x, argsort = y_div_x.sort(dim=-1) + + wx = torch.gather(x * w, dim=-1, index=argsort) + derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True) + search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1) + + a = y_div_x.gather(dim=-1, index=search).squeeze(-1) + index = argsort.gather(dim=-1, index=search).squeeze(-1) + loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1) + + else: + # Reshape to (batch_size, n) for simplicity + x, y, w = torch.broadcast_tensors(x, y, w) + batch_shape = x.shape[:-1] + batch_size = math.prod(batch_shape) + x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1]) + + sign = torch.sign(x) + x, y = x * sign, y * sign + wx, wy = w * x, w * y + xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering + + y_div_x = A = y / x.clamp_min(eps) + B = (wy - trunc) / wx.clamp_min(eps) + C = (wy + trunc) / wx.clamp_min(eps) + with torch.no_grad(): + # Caculate prefix sum by orders of A, B, C + A, A_argsort = A.sort(dim=-1) + Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1) + A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases. + + B, B_argsort = B.sort(dim=-1) + Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1) + B, Q_B = _pad_inf(B), _pad_cumsum(Q_B) + + C, C_argsort = C.sort(dim=-1) + Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1) + C, Q_C = _pad_inf(C), _pad_cumsum(Q_C) + + # Caculate left and right derivative of A + j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1) + j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1) + j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1) + left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C) + j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1) + j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1) + j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1) + right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C) + + # Find extrema + is_extrema = (left_derivative < 0) & (right_derivative >= 0) + is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema. + where_extrema_batch, where_extrema_index = torch.where(is_extrema) + + # Calculate objective value at extrema + extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,) + MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM in case there are too many extrema.(~1G) + SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1] + extrema_value = torch.cat([ + _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc) + for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE)) + ]) # (num_extrema,) + + # Find minima among corresponding extrema + minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,) + index = where_extrema_index[indices] + + a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps) + a = a.reshape(batch_shape) + loss = minima.reshape(batch_shape) + index = index.reshape(batch_shape) + + return a, loss, index + + +def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None): + """ + Align `depth_src` to `depth_tgt` with given constant weights. + + ### Parameters: + - `depth_src: torch.Tensor` of shape (..., N) + - `depth_tgt: torch.Tensor` of shape (..., N) + + """ + scale, _, _ = align(depth_src, depth_tgt, weight, trunc) + + return scale + + +def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None): + """ + Align `depth_src` to `depth_tgt` with given constant weights. + + ### Parameters: + - `depth_src: torch.Tensor` of shape (..., N) + - `depth_tgt: torch.Tensor` of shape (..., N) + - `weight: torch.Tensor` of shape (..., N) + - `trunc: float` or tensor of shape (..., N) or None + + ### Returns: + - `scale: torch.Tensor` of shape (...). + - `shift: torch.Tensor` of shape (...). + """ + dtype, device = depth_src.dtype, depth_src.device + + # Flatten batch dimensions for simplicity + batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1] + batch_size = math.prod(batch_shape) + depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n) + + # Here, we take anchors only for non-zero weights. + # Although the results will be still correct even anchor points have zero weight, + # it is wasting computation and may cause instability in some cases, e.g. too many extrema. + anchors_where_batch, anchors_where_n = torch.where(weight > 0) + + # Stop gradient when solving optimal anchors + with torch.no_grad(): + depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors) + depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors) + + depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n) + depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n) + weight_anchored = weight[anchors_where_batch, :] # (anchors, n) + + scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors) + + loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,) + + # Reproduce by indexing for shorter compute graph + index_1 = anchors_where_n[index_anchor] # (batch_size,) + index_2 = index[index_anchor] # (batch_size,) + + tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1) + tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1) + + scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7) + shift = tgt_1 - scale * src_1 + + scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape) + + return scale, shift + +def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12): + """ + Align `depth_src` to `depth_tgt` with given constant weights using IRLS. + """ + dtype, device = depth_src.dtype, depth_src.device + + w = weight + x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1) + y = depth_tgt + + for i in range(max_iter): + beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1) + w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps) + + return beta[..., 0], beta[..., 1] + + +def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None): + """ + ### Parameters: + - `points_src: torch.Tensor` of shape (..., N, 3) + - `points_tgt: torch.Tensor` of shape (..., N, 3) + - `weight: torch.Tensor` of shape (..., N) + + ### Returns: + - `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it. + - `b: torch.Tensor` of shape (...) + """ + dtype, device = points_src.dtype, points_src.device + + scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc) + + return scale + + +def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None): + """ + Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift. + It is similar to `align_affine` but scale and shift are applied to different dimensions. + + ### Parameters: + - `points_src: torch.Tensor` of shape (..., N, 3) + - `points_tgt: torch.Tensor` of shape (..., N, 3) + - `weights: torch.Tensor` of shape (..., N) + + ### Returns: + - `scale: torch.Tensor` of shape (...). + - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros. + """ + dtype, device = points_src.dtype, points_src.device + + # Flatten batch dimensions for simplicity + batch_shape, n = points_src.shape[:-2], points_src.shape[-2] + batch_size = math.prod(batch_shape) + points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n) + + # Take anchors + anchor_where_batch, anchor_where_n = torch.where(weight > 0) + with torch.no_grad(): + zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype) + points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3) + points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3) + + points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3) + points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3) + weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3) + + # Solve optimal scale and shift for each anchor + MAX_ELEMENTS = 2 ** 20 + scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,) + + loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,) + + # Reproduce by indexing for shorter compute graph + index_2 = index[index_anchor] # (batch_size,) [0, 3n) + index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n) + + zeros = torch.zeros((batch_size, n), device=device, dtype=dtype) + points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1) + tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1) + tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1) + + scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0) + shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) + scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3) + + return scale, shift + + +def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6): + """ + Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift. + It is similar to `align_affine` but scale and shift are applied to different dimensions. + + ### Parameters: + - `points_src: torch.Tensor` of shape (..., N, 3) + - `points_tgt: torch.Tensor` of shape (..., N, 3) + - `weights: torch.Tensor` of shape (..., N) + + ### Returns: + - `scale: torch.Tensor` of shape (...). + - `shift: torch.Tensor` of shape (..., 3) + """ + dtype, device = points_src.dtype, points_src.device + + # Flatten batch dimensions for simplicity + batch_shape, n = points_src.shape[:-2], points_src.shape[-2] + batch_size = math.prod(batch_shape) + points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n) + + # Take anchors + anchor_where_batch, anchor_where_n = torch.where(weight > 0) + + with torch.no_grad(): + points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3) + points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3) + + points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3) + points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3) + weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3) + + # Solve optimal scale and shift for each anchor + MAX_ELEMENTS = 2 ** 20 + scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,) + + # Get optimal scale and shift for each batch element + loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,) + + index_2 = index[index_anchor] # (batch_size,) [0, 3n) + index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n) + + src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1) + src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1) + + scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0) + shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) + + scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3) + + return scale, shift + + +def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6): + """ + Align `points_src` to `points_tgt` with respect to a Z-axis shift. + + ### Parameters: + - `points_src: torch.Tensor` of shape (..., N, 3) + - `points_tgt: torch.Tensor` of shape (..., N, 3) + - `weights: torch.Tensor` of shape (..., N) + + ### Returns: + - `scale: torch.Tensor` of shape (...). + - `shift: torch.Tensor` of shape (..., 3) + """ + dtype, device = points_src.dtype, points_src.device + + shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc) + shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1) + + return shift + + +def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6): + """ + Align `points_src` to `points_tgt` with respect to a Z-axis shift. + + ### Parameters: + - `points_src: torch.Tensor` of shape (..., N, 3) + - `points_tgt: torch.Tensor` of shape (..., N, 3) + - `weights: torch.Tensor` of shape (..., N) + + ### Returns: + - `scale: torch.Tensor` of shape (...). + - `shift: torch.Tensor` of shape (..., 3) + """ + dtype, device = points_src.dtype, points_src.device + + shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc) + + return shift + + +def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares. + + ### Parameters: + - `x: torch.Tensor` of shape (..., N) + - `y: torch.Tensor` of shape (..., N) + - `w: torch.Tensor` of shape (..., N) + + ### Returns: + - `a: torch.Tensor` of shape (...,) + - `b: torch.Tensor` of shape (...,) + """ + w_sqrt = torch.ones_like(x) if w is None else w.sqrt() + A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1) + B = (w_sqrt * y)[..., None] + a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1) + return a, b \ No newline at end of file diff --git a/moge/utils/download.py b/moge/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..886edbccc81cc0c3daed4d858f641097bdfceee2 --- /dev/null +++ b/moge/utils/download.py @@ -0,0 +1,55 @@ +from pathlib import Path +from typing import * +import requests + +from tqdm import tqdm + + +__all__ = ["download_file", "download_bytes"] + + +def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None: + # Ensure headers is a dict if not provided + headers = headers or {} + + # Initialize local variables + file_path = Path(filepath) + downloaded_bytes = 0 + + # Check if we should resume the download + if resume and file_path.exists(): + downloaded_bytes = file_path.stat().st_size + headers['Range'] = f"bytes={downloaded_bytes}-" + + # Make a GET request to fetch the file + with requests.get(url, stream=True, headers=headers) as response: + response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx + + # Calculate the total size to download + total_size = downloaded_bytes + int(response.headers.get('content-length', 0)) + + # Display a progress bar while downloading + with ( + tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar, + open(file_path, 'ab') as file, + ): + # Set the initial position of the progress bar + pbar.update(downloaded_bytes) + + # Write the content to the file in chunks + for chunk in response.iter_content(chunk_size=4096): + file.write(chunk) + pbar.update(len(chunk)) + + +def download_bytes(url: str, headers: dict = None) -> bytes: + # Ensure headers is a dict if not provided + headers = headers or {} + + # Make a GET request to fetch the file + with requests.get(url, stream=True, headers=headers) as response: + response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx + + # Read the content of the response + return response.content + \ No newline at end of file diff --git a/moge/utils/geometry_numpy.py b/moge/utils/geometry_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..6975471e9fb7443d5a615a47de94d49841c789e1 --- /dev/null +++ b/moge/utils/geometry_numpy.py @@ -0,0 +1,406 @@ +from typing import * +from functools import partial +import math + +import cv2 +import numpy as np +from scipy.signal import fftconvolve +import numpy as np +import utils3d + +from .tools import timeit + + +def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return np.mean(x, axis=axis) + else: + w = w.astype(x.dtype) + return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None) + + +def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis) + else: + w = w.astype(x.dtype) + return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps) + + +def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype) + v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + uv = np.stack([u, v], axis=-1) + return uv + + +def focal_to_fov_numpy(focal: np.ndarray): + return 2 * np.arctan(0.5 / focal) + + +def fov_to_focal_numpy(fov: np.ndarray): + return 0.5 / np.tan(fov / 2) + + +def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0]) + fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1]) + return fov_x, fov_y + + +def point_map_to_depth_legacy_numpy(points: np.ndarray): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2) + _, uv = np.broadcast_arrays(points[..., :2], uv) + + # Solve least squares problem + b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2) + A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2) + + M = A.swapaxes(-2, -1) @ A + solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution + + depth = points[..., 2] + shift[..., None, None] + fov_x = np.arctan(width / diagonal / focal) * 2 + fov_y = np.arctan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[: , None] + f = (xy_proj * uv).sum() / np.square(xy_proj).sum() + err = (f * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + + xy_proj = xy / (z + optim_shift)[: , None] + optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum() + + return optim_shift, optim_focal + + +def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[: , None] + err = (focal * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + + return optim_shift + + +def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)): + import cv2 + assert points.shape[-1] == 3, "Points should (H, W, 3)" + + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + uv = normalized_view_plane_uv_numpy(width=width, height=height) + + if mask is None: + points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3) + uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2) + else: + (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size) + + if points_lr.size < 2: + return 1., 0. + + if focal is None: + focal, shift = solve_optimal_focal_shift(uv_lr, points_lr) + else: + shift = solve_optimal_shift(uv_lr, points_lr, focal) + + return focal, shift + + +def mask_aware_nearest_resize_numpy( + inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None], + mask: np.ndarray, + size: Tuple[int, int], + return_index: bool = False +) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...). + - `mask`: input 2D mask of shape (..., H, W) + - `size`: target size (width, height) + + ### Returns + - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...). + - `resized_mask`: mask of the resized map of shape (..., target_height, target_width) + - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension. + """ + height, width = mask.shape[-2:] + target_width, target_height = size + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1 + + # Window the original mask and uv + uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) + indices = np.arange(height * width, dtype=np.int32).reshape(height, width) + padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) + windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + + # Gather the target pixels's local window + target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) + target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) + + target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + + # Compute nearest neighbor in the local window for each pixel + dist = np.square(target_window_centers - target_centers[..., None]) + dist = dist[..., 0, :] + dist[..., 1, :] + dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size) + nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1) + nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + target_mask = np.any(target_window_mask, axis=-1) + batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + index = (*batch_indices, nearest_i, nearest_j) + + if inputs is None: + outputs = None + elif isinstance(inputs, np.ndarray): + outputs = inputs[index] + elif isinstance(inputs, Sequence): + outputs = tuple(x[index] for x in inputs) + else: + raise ValueError(f'Invalid input type: {type(inputs)}') + + if return_index: + return outputs, target_mask, index + else: + return outputs, target_mask + + +def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `image`: Input 2D image of shape (..., H, W, C) + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + + if image.shape[-2:] == (height, width): + omit_channel_dim = True + else: + omit_channel_dim = False + if omit_channel_dim: + image = image[..., None] + + image = np.where(mask[..., None], image, 0) + + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1 + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1 + + # Window the original mask and uv (non-copy) + uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) + indices = np.arange(height * width, dtype=np.int32).reshape(height, width) + padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) + windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + + # Gather the target pixels's local window + target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) + target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) + + target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + + # Compute pixel area in the local windows + target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None]) + target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None]) + target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None) + target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0) + + # Weighted sum by area + target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1) + target_mask = np.sum(target_window_area, axis=-1) >= 0.25 + target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1) + + if omit_channel_dim: + target_image = target_image[..., 0] + + return target_image, target_mask + + +def norm3d(x: np.ndarray) -> np.ndarray: + "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors" + return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2])) + + +def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1): + disp = np.where(mask, 1 / depth, 0) + disp_pad = np.pad(disp, (thickness, thickness), constant_values=0) + mask_pad = np.pad(mask, (thickness, thickness), constant_values=False) + kernel_size = 2 * thickness + 1 + disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2] + mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2] + + disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1)) + fg_edge_mask = mask & (disp > (1 + tol) * disp_mean) + bg_edge_mask = mask & (disp_mean > (1 + tol) * disp) + + edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \ + & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) + + return edge_mask + + +def disk_kernel(radius: int) -> np.ndarray: + """ + Generate disk kernel with given radius. + + Args: + radius (int): Radius of the disk (in pixels). + + Returns: + np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel. + """ + # Create coordinate grid centered at (0,0) + L = np.arange(-radius, radius + 1) + X, Y = np.meshgrid(L, L) + # Generate disk: region inside circle with radius R is 1 + kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32) + # Normalize the kernel + kernel /= np.sum(kernel) + return kernel + + +def disk_blur(image: np.ndarray, radius: int) -> np.ndarray: + """ + Apply disk blur to an image using FFT convolution. + + Args: + image (np.ndarray): Input image, can be grayscale or color. + radius (int): Blur radius (in pixels). + + Returns: + np.ndarray: Blurred image. + """ + if radius == 0: + return image + kernel = disk_kernel(radius) + if image.ndim == 2: + blurred = fftconvolve(image, kernel, mode='same') + elif image.ndim == 3: + channels = [] + for i in range(image.shape[2]): + blurred_channel = fftconvolve(image[..., i], kernel, mode='same') + channels.append(blurred_channel) + blurred = np.stack(channels, axis=-1) + else: + raise ValueError("Image must be 2D or 3D.") + return blurred + + +def depth_of_field( + img: np.ndarray, + disp: np.ndarray, + focus_disp : float, + max_blur_radius : int = 10, +) -> np.ndarray: + """ + Apply depth of field effect to an image. + + Args: + img (numpy.ndarray): (H, W, 3) input image. + depth (numpy.ndarray): (H, W) depth map of the scene. + focus_depth (float): Focus depth of the lens. + strength (float): Strength of the depth of field effect. + max_blur_radius (int): Maximum blur radius (in pixels). + + Returns: + numpy.ndarray: (H, W, 3) output image with depth of field effect applied. + """ + # Precalculate dialated depth map for each blur radius + max_disp = np.max(disp) + disp = disp / max_disp + focus_disp = focus_disp / max_disp + dilated_disp = [] + for radius in range(max_blur_radius + 1): + dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1)) + + # Determine the blur radius for each pixel based on the depth map + blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32) + for radius in range(max_blur_radius + 1): + dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32) + mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp) + blur_radii[mask] = dialted_blur_radii[mask] + blur_radii = np.clip(blur_radii, 0, max_blur_radius) + blur_radii = cv2.blur(blur_radii, (5, 5)) + + # Precalculate the blured image for each blur radius + unique_radii = np.unique(blur_radii) + precomputed = {} + for radius in range(max_blur_radius + 1): + if radius not in unique_radii: + continue + precomputed[radius] = disk_blur(img, radius) + + # Composit the blured image for each pixel + output = np.zeros_like(img) + for r in unique_radii: + mask = blur_radii == r + output[mask] = precomputed[r][mask] + + return output diff --git a/moge/utils/geometry_torch.py b/moge/utils/geometry_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5dbe965a42d0e0b3cbe53eb213bdcb829f8243 --- /dev/null +++ b/moge/utils/geometry_torch.py @@ -0,0 +1,354 @@ +from typing import * +import math +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +import utils3d + +from .tools import timeit +from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift + + +def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.mean(dim=dim, keepdim=keepdim) + else: + w = w.to(x.dtype) + return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps) + + +def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal() + + +def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).log().mean(dim=dim).exp() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp() + + +def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: + kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2)) + kernel = kernel / kernel.sum() + kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size) + input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate') + input = F.conv2d(input, kernel, groups=input.shape[1]) + return input + + +def focal_to_fov(focal: torch.Tensor): + return 2 * torch.atan(0.5 / focal) + + +def fov_to_focal(fov: torch.Tensor): + return 0.5 / torch.tan(fov / 2) + + +def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12): + return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1)) + +def intrinsics_to_fov(intrinsics: torch.Tensor): + """ + Returns field of view in radians from normalized intrinsics matrix. + ### Parameters: + - intrinsics: torch.Tensor of shape (..., 3, 3) + + ### Returns: + - fov_x: torch.Tensor of shape (...) + - fov_y: torch.Tensor of shape (...) + """ + focal_x = intrinsics[..., 0, 0] + focal_y = intrinsics[..., 1, 1] + return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y) + + +def point_map_to_depth_legacy(points: torch.Tensor): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + # Solve least squares problem + b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2) + A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2) + + M = A.transpose(-2, -1) @ A + solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution.unbind(-1) + + depth = points[..., 2] + shift[..., None, None] + fov_x = torch.atan(width / diagonal / focal) * 2 + fov_y = torch.atan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def view_plane_uv_to_focal(uv: torch.Tensor): + normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype) + focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12) + return focal + + +def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map + - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + focal = focal.reshape(-1) if focal is not None else None + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1) + uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0) + mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0 + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + focal_np = focal.cpu().numpy() if focal is not None else None + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + if uv_lr_i_np.shape[0] < 2: + optim_focal.append(1) + optim_shift.append(0) + continue + if focal is None: + optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np) + optim_focal.append(float(optim_focal_i)) + else: + optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i]) + optim_shift.append(float(optim_shift_i)) + optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + + if focal is None: + optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + else: + optim_focal = focal.reshape(shape[:-3]) + + return optim_focal, optim_shift + + +def mask_aware_nearest_resize( + inputs: Union[torch.Tensor, Sequence[torch.Tensor], None], + mask: torch.BoolTensor, + size: Tuple[int, int], + return_index: bool = False +) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...). + - `mask`: input 2D mask of shape (..., H, W) + - `size`: target size (target_width, target_height) + + ### Returns + - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...). + - `resized_mask`: mask of the resized map of shape (..., target_height, target_width) + - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, . + """ + height, width = mask.shape[-2:] + target_width, target_height = size + device = mask.device + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1 + + # Window the original mask and uv + uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device) + indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width) + padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1)) + windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device) + target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device) + target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device) + + target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + target_window_indices = target_window_indices.expand_as(target_window_mask) + + # Compute nearest neighbor in the local window for each pixel + dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size) + nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1) + nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width) + target_mask = torch.any(target_window_mask, dim=-1) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + index = (*batch_indices, nearest_i, nearest_j) + + if inputs is None: + outputs = None + elif isinstance(inputs, torch.Tensor): + outputs = inputs[index] + elif isinstance(inputs, Sequence): + outputs = tuple(x[index] for x in inputs) + else: + raise ValueError(f'Invalid input type: {type(inputs)}') + + if return_index: + return outputs, target_mask, index + else: + return outputs, target_mask + + +def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3): + *batch_shape, height, width = depth.shape + depth = depth.reshape(-1, 1, height, width) + mask = mask.reshape(-1, 1, height, width) + if pooler =='max': + pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + output_mask = pooled_depth > depth * (1 + rtol) + elif pooler =='min': + pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + output_mask = pooled_depth < depth * (1 - rtol) + else: + raise ValueError(f'Unsupported pooler: {pooler}') + output_mask = output_mask.reshape(*batch_shape, height, width) + return output_mask + + +def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1): + device, dtype = depth.device, depth.dtype + + disp = torch.where(mask, 1 / depth, 0) + disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0) + mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False) + disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2] + mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2] + + x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype) + A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3) # [kernel_size ** 2, 3] + A = mask_window[..., None] * A + I = torch.eye(3, device=device, dtype=dtype) + + affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :] # [..., H, W, kernel_size ** 2] + diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0) + + edge_mask = mask & (diff > tol).any(dim=-1) + + disp_mean = weighted_mean(disp_window, mask_window, dim=-1) + fg_edge_mask = edge_mask & (disp > disp_mean) + # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size) + bg_edge_mask = edge_mask & ~fg_edge_mask + return fg_edge_mask, bg_edge_mask + + +def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1): + device, dtype = depth.device, depth.dtype + + disp = torch.where(mask, 1 / depth, 0) + disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0) + mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False) + disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2] + mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2] + + disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1)) + fg_edge_mask = mask & (disp / disp_mean > 1 + tol) + bg_edge_mask = mask & (disp_mean / disp > 1 + tol) + + fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool() + bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool() + + return fg_edge_mask, bg_edge_mask + + +def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor: + kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool) + for _ in range(iterations): + input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1)) + mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1)) + if filter =='min': + input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values) + elif filter =='max': + input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values) + elif filter == 'mean': + input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1))) + elif filter =='median': + input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values) + mask = mask_window.any(dim=(-2, -1)) + return input, mask + + +def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor: + device, dtype = depth.device, depth.dtype + height, width = depth.shape[-2:] + radius = kernel_size // 2 + + duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device) + + log_depth = depth.clamp_min_(eps).log() + log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None] + + weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square()) + tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps) + + uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype) + K_inv = torch.inverse(intrinsics) + + grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \ + / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2) + laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1)) + + laplacian = laplacian.clamp(-0.1, 0.1) + log_depth_refine = log_depth.clone() + + for _ in range(iterations): + log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp) + + depth_refine = log_depth_refine.exp() + + return depth_refine \ No newline at end of file diff --git a/moge/utils/io.py b/moge/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..c185a7d449390a0c4b037bae293fbbe6de8d7d4e --- /dev/null +++ b/moge/utils/io.py @@ -0,0 +1,236 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from typing import IO +import zipfile +import json +import io +from typing import * +from pathlib import Path +import re +from PIL import Image, PngImagePlugin + +import numpy as np +import cv2 + +from .tools import timeit + + +def save_glb( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_uvs: np.ndarray, + color_texture: np.ndarray, + normal_texture: np.ndarray = None, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + faces=faces, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(color_texture), + normalTexture=Image.fromarray(normal_texture), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + process=False + ).export(save_path) + + +def save_ply( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_colors: np.ndarray, + vertex_normals: Optional[np.ndarray] = None, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=vertex_colors, + vertex_normals=vertex_normals, + process=False + ).export(save_path) + + +def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray: + """ + Read a image, return uint8 RGB array of shape (H, W, 3). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + return image + + +def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95): + """ + Write a image, input uint8 RGB array of shape (H, W, 3). + """ + data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes() + if isinstance(path, (str, os.PathLike)): + Path(path).write_bytes(data) + else: + path.write(data) + + +def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]: + """ + Read a depth image, return float32 depth array of shape (H, W). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + pil_image = Image.open(io.BytesIO(data)) + near = float(pil_image.info.get('near')) + far = float(pil_image.info.get('far')) + unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None + depth = np.array(pil_image) + mask_nan, mask_inf = depth == 0, depth == 65535 + depth = (depth.astype(np.float32) - 1) / 65533 + depth = near ** (1 - depth) * far ** depth + depth[mask_nan] = np.nan + depth[mask_inf] = np.inf + return depth, unit + + +def write_depth( + path: Union[str, os.PathLike, IO], + depth: np.ndarray, + unit: float = None, + max_range: float = 1e5, + compression_level: int = 7, +): + """ + Encode and write a depth image as 16-bit PNG format. + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to write to. + - `depth: np.ndarray` + The depth array, float32 array of shape (H, W). + May contain `NaN` for invalid values and `Inf` for infinite values. + - `unit: float = None` + The unit of the depth values. + + Depth values are encoded as follows: + - 0: unknown + - 1 ~ 65534: depth values in logarithmic + - 65535: infinity + + metadata is stored in the PNG file as text fields: + - `near`: the minimum depth value + - `far`: the maximum depth value + - `unit`: the unit of the depth values (optional) + """ + mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth) + + depth = depth.astype(np.float32) + mask_finite = depth + near = max(depth[mask_values].min(), 1e-5) + far = max(near * 1.1, min(depth[mask_values].max(), near * max_range)) + depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534 + depth[mask_nan] = 0 + depth[mask_inf] = 65535 + + pil_image = Image.fromarray(depth) + pnginfo = PngImagePlugin.PngInfo() + pnginfo.add_text('near', str(near)) + pnginfo.add_text('far', str(far)) + if unit is not None: + pnginfo.add_text('unit', str(unit)) + pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) + + +def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]: + """ + Read a segmentation mask + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to read from. + ### Returns: + - `Tuple[np.ndarray, Dict[str, int]]` + A tuple containing: + - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W). + - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}. + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + pil_image = Image.open(io.BytesIO(data)) + labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None + mask = np.array(pil_image) + return mask, labels + + +def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7): + """ + Write a segmentation mask and label mapping, as PNG format. + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to write to. + - `mask: np.ndarray` + The segmentation mask, uint8 or uint16 array of shape (H, W). + - `labels: Dict[str, int] = None` + The label mapping, a dictionary of {label_name: label_id}. + - `compression_level: int = 7` + The compression level for PNG compression. + """ + assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}" + pil_image = Image.fromarray(mask) + pnginfo = PngImagePlugin.PngInfo() + if labels is not None: + labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':')) + pnginfo.add_text('labels', labels_json) + pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) + + + +def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray: + """ + Read a normal image, return float32 normal array of shape (H, W, 3). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB) + mask_nan = np.all(normal == 0, axis=-1) + normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0] + normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12) + normal[mask_nan] = np.nan + return normal + + +def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray: + """ + Write a normal image, input float32 normal array of shape (H, W, 3). + """ + mask_nan = np.isnan(normal).any(axis=-1) + normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16) + normal[mask_nan] = 0 + data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes() + if isinstance(path, (str, os.PathLike)): + Path(path).write_bytes(data) + else: + path.write(data) + + +def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]: + return json.loads(Path(path).read_text()) + +def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]): + Path(path).write_text(json.dumps(meta)) \ No newline at end of file diff --git a/moge/utils/panorama.py b/moge/utils/panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9d121c3c189770a7fd9f88be66f74f1ba5cfd3 --- /dev/null +++ b/moge/utils/panorama.py @@ -0,0 +1,191 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +from typing import * +import itertools +import json +import warnings + +import cv2 +import numpy as np +from numpy import ndarray +from tqdm import tqdm, trange +from scipy.sparse import csr_array, hstack, vstack +from scipy.ndimage import convolve +from scipy.sparse.linalg import lsmr + +import utils3d + + +def get_panorama_cameras(): + vertices, _ = utils3d.numpy.icosahedron() + intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90)) + extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32) + return extrinsics, [intrinsics] * len(vertices) + + +def spherical_uv_to_directions(uv: np.ndarray): + theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi + directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1) + return directions + + +def directions_to_spherical_uv(directions: np.ndarray): + directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True) + u = 1 - np.arctan2(directions[..., 1], directions[..., 0]) / (2 * np.pi) % 1.0 + v = np.arccos(directions[..., 2]) / np.pi + return np.stack([u, v], axis=-1) + + +def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int): + height, width = image.shape[:2] + uv = utils3d.numpy.image_uv(width=resolution, height=resolution) + splitted_images = [] + for i in range(len(extrinsics)): + spherical_uv = directions_to_spherical_uv(utils3d.numpy.unproject_cv(uv, extrinsics=extrinsics[i], intrinsics=intrinsics[i])) + pixels = utils3d.numpy.uv_to_pixel(spherical_uv, width=width, height=height).astype(np.float32) + + splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR) + splitted_images.append(splitted_image) + return splitted_images + + +def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, ndarray]: + grid_index = np.arange(height * width).reshape(height, width) + grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge') + grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge') + + data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1) + indices = np.stack([ + grid_index[1:-1, 1:-1], + grid_index[:-2, 1:-1], # up + grid_index[2:, 1:-1], # down + grid_index[1:-1, :-2], # left + grid_index[1:-1, 2:] # right + ], axis=-1).reshape(-1) + indptr = np.arange(0, height * width * 5 + 1, 5) + A = csr_array((data, indices, indptr), shape=(height * width, height * width)) + + return A + + +def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, np.ndarray]: + grid_index = np.arange(width * height).reshape(height, width) + if wrap_x: + grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap') + if wrap_y: + grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap') + + data = np.concatenate([ + np.concatenate([ + np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j] + -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j-1] + ], axis=1).reshape(-1), + np.concatenate([ + np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i,j] + -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i-1,j] + ], axis=1).reshape(-1), + ]) + indices = np.concatenate([ + np.concatenate([ + grid_index[:, :-1].reshape(-1, 1), + grid_index[:, 1:].reshape(-1, 1), + ], axis=1).reshape(-1), + np.concatenate([ + grid_index[:-1, :].reshape(-1, 1), + grid_index[1:, :].reshape(-1, 1), + ], axis=1).reshape(-1), + ]) + indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2) + A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width)) + + return A + + +def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]): + if max(width, height) > 256: + panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics) + panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR) + else: + panorama_depth_init = None + + uv = utils3d.numpy.image_uv(width=width, height=height) + spherical_directions = spherical_uv_to_directions(uv) + + # Warp each view to the panorama + panorama_log_distance_grad_maps, panorama_grad_masks = [], [] + panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], [] + panorama_pred_masks = [] + for i in range(len(distance_maps)): + projected_uv, projected_depth = utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i]) + projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1) + + projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32) + + log_splitted_distance = np.log(distance_maps[i]) + panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0) + panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0) + + # calculate gradient map + padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap') + grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :] + + padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap') + mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :] + + panorama_log_distance_grad_maps.append((grad_x, grad_y)) + panorama_grad_masks.append((mask_x, mask_y)) + + # calculate laplacian map + padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge') + padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') + laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1] + + padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge') + padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') + mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5 + + panorama_log_distance_laplacian_maps.append(laplacian) + panorama_laplacian_masks.append(mask) + + panorama_pred_masks.append(panorama_pred_mask) + + panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0) + panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0) + panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0) + panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0) + + panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3) + panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3) + + panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0) + panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0) + panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3) + + grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1) + grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1) + grad_mask = np.concatenate([grad_x_mask, grad_y_mask]) + laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1) + + # Solve overdetermined system + A = vstack([ + grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask], + poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask], + ]) + b = np.concatenate([ + panorama_log_distance_grad_x.reshape(-1)[grad_x_mask], + panorama_log_distance_grad_y.reshape(-1)[grad_y_mask], + panorama_laplacian_map.reshape(-1)[laplacian_mask] + ]) + x, *_ = lsmr( + A, b, + atol=1e-5, btol=1e-5, + x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None, + show=False, + ) + + panorama_depth = np.exp(x).reshape(height, width).astype(np.float32) + panorama_mask = np.any(panorama_pred_masks, axis=0) + + return panorama_depth, panorama_mask + diff --git a/moge/utils/pipeline.py b/moge/utils/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..daa522e987317e949899d4159e61d7a7066e1fba --- /dev/null +++ b/moge/utils/pipeline.py @@ -0,0 +1,503 @@ +from typing import * +from abc import abstractmethod +from queue import Empty, Full +from threading import Thread +from queue import Queue +from multiprocessing import Process +from threading import Thread, Event +import multiprocessing +import threading +import inspect +import time +import uuid +from copy import deepcopy +import itertools +import functools + +__all__ = [ + 'Node', + 'Link', + 'ConcurrentNode', + 'Worker', + 'WorkerFunction', + 'Provider', + 'ProviderFunction', + 'Sequential', + 'Batch', + 'Unbatch', + 'Parallel', + 'Graph', + 'Buffer', +] + +TERMINATE_CHECK_INTERVAL = 0.5 + + +class _ItemWrapper: + def __init__(self, data: Any, id: Union[int, List[int]] = None): + self.data = data + self.id = id + + +class Terminate(Exception): + pass + + +def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper: + while True: + try: + item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL)) + if terminate_flag.is_set(): + raise Terminate() + return item + except Empty: + if terminate_flag.is_set(): + raise Terminate() + + if timeout is not None: + timeout -= TERMINATE_CHECK_INTERVAL + if timeout <= 0: + raise Empty() + + +def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event): + while True: + try: + queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL) + if terminate_flag.is_set(): + raise Terminate() + return + except Full: + if terminate_flag.is_set(): + raise Terminate() + +class Node: + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + self.input: Queue = Queue(maxsize=in_buffer_size) + self.output: Queue = Queue(maxsize=out_buffer_size) + self.in_buffer_size = in_buffer_size + self.out_buffer_size = out_buffer_size + + @abstractmethod + def start(self): + pass + + @abstractmethod + def terminate(self): + pass + + def stop(self): + self.terminate() + self.join() + + @abstractmethod + def join(self): + pass + + def put(self, data: Any, key: str = None, block: bool = True) -> None: + item = _ItemWrapper(data) + self.input.put(item, block=block) + + def get(self, key: str = None, block: bool = True) -> Any: + item: _ItemWrapper = self.output.get(block=block) + return item.data + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.terminate() + self.join() + + +class ConcurrentNode(Node): + job: Union[Thread, Process] + + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(in_buffer_size, out_buffer_size) + self.running_as = running_as + + @abstractmethod + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + pass + + def start(self): + if self.running_as == 'thread': + terminate_flag = threading.Event() + job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + elif self.running_as == 'process': + terminate_flag = multiprocessing.Event() + job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + job.start() + self.job = job + self.terminate_flag = terminate_flag + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.job.join() + + +class Worker(ConcurrentNode): + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread is started, to initialize any resources that is only held in the thread. + """ + pass + + @abstractmethod + def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]: + """ + This method defines the job that the node should do for each input item. + A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue. + The method is executed concurrently with other nodes. + """ + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + while True: + item = _get_queue_item(input, terminate_flag) + result = self.work(item.data) + _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag) + + except Terminate: + return + + +class Provider(ConcurrentNode): + """ + A node that provides data to successive nodes. It takes no input and provides data to the output queue. + """ + def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None: + super().__init__(running_as, 0, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process. + """ + pass + + @abstractmethod + def provide(self) -> Generator[Any, None, None]: + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + for data in self.provide(): + _put_queue_item(output, _ItemWrapper(data), terminate_flag) + except Terminate: + return + + +class WorkerFunction(Worker): + def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + self.fn = fn + + def work(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + +class ProviderFunction(Provider): + def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None: + super().__init__(running_as, out_buffer_size) + self.fn = fn + + def provide(self): + for item in self.fn(): + yield item + + +class Link: + def __init__(self, src: Queue, dst: Queue): + self.src = src + self.dst = dst + + def _thread_fn(self): + try: + while True: + item = _get_queue_item(self.src, self.terminate_flag) + _put_queue_item(self.dst, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.thread = Thread(target=self._thread_fn) + self.thread.start() + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.thread.join() + + +class Graph(Node): + """ + Graph pipeline of nodes and links + """ + nodes: List[Node] + links: List[Link] + + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + self.links = [] + + def add(self, node: Node): + self.nodes.append(node) + + def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]): + """ + Links the output of the source node to the input of the destination node. + If the source or destination node is None, the pipeline's input or output is used. + """ + src_queue = self.input if src is None else src.output + dst_queue = self.output if dst is None else dst.input + self.links.append(Link(src_queue, dst_queue)) + + def chain(self, nodes: Iterable[Node]): + """ + Link the output of each node to the input of the next node. + """ + nodes = list(nodes) + for i in range(len(nodes) - 1): + self.link(nodes[i], nodes[i + 1]) + + def start(self): + for node in self.nodes: + node.start() + for link in self.links: + link.start() + + def terminate(self): + for node in self.nodes: + node.terminate() + for link in self.links: + link.terminate() + + def join(self): + for node in self.nodes: + node.join() + for link in self.links: + link.join() + + def __iter__(self): + providers = [node for node in self.nodes if isinstance(node, Provider)] + if len(providers) == 0: + raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.") + with self: + # while all(provider.job.is_alive() for provider in providers): + while True: + yield self.get() + + def __call__(self, data: Any) -> Any: + """ + Submit data to the pipeline's input queue, and return the output data asynchronously. + NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work. + """ + # TODO + + +class Sequential(Graph): + """ + Pipeline of nodes in sequential order, where each node takes the output of the previous node as input. + The order of input and output items is preserved (FIFO) + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute sequentially. + ### Parameters: + - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + self.chain([None, *self.nodes, None]) + + +class Parallel(Node): + """ + A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available. + NOTE: It is FIFO if and only if all the nested nodes are FIFO. + """ + nodes: List[Node] + + def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.nodes.append(node) + self.output_order = Queue() + self.lock = threading.Lock() + + def _in_thread_fn(self, node: Node): + try: + while True: + with self.lock: + # A better idea: first make sure its node is vacant, then get it a new item. + # Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node. + # This could lead to suboptimal scheduling. + item = _get_queue_item(self.input, self.terminate_flag) + self.output_order.put(node.output) + _put_queue_item(node.input, item, self.terminate_flag) + except Terminate: + return + + def _out_thread_fn(self): + try: + while True: + queue = _get_queue_item(self.output_order, self.terminate_flag) + item = _get_queue_item(queue, self.terminate_flag) + _put_queue_item(self.output, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.in_threads = [] + for node in self.nodes: + thread = Thread(target=self._in_thread_fn, args=(node,)) + thread.start() + self.in_threads.append(thread) + thread = Thread(target=self._out_thread_fn) + thread.start() + self.out_thread = thread + for node in self.nodes: + node.start() + + def terminate(self): + self.terminate_flag.set() + for node in self.nodes: + node.terminate() + + def join(self): + for thread in self.in_threads: + thread.join() + self.out_thread.join() + + +class UnorderedParallel(Graph): + """ + Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available. + NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input. + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node. + ### Parameters: + - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + for i in range(len(nodes)): + self.chain([None, self.nodes[i], None]) + + +class Batch(ConcurrentNode): + """ + Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes. + The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node, + i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size. + """ + def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1): + assert batch_size > 0, "Batch size must be greater than 0." + super().__init__('thread', in_buffer_size, out_buffer_size) + self.batch_size = batch_size + self.patience = patience + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch_id, batch_data = [], [] + # Try to fill the batch + for i in range(self.batch_size): + if i == 0 or self.patience is None: + timeout = None + else: + timeout = self.patience - (time.time() - earliest_time) + if timeout < 0: + break + try: + item = _get_queue_item(input, terminate_flag, timeout) + except Empty: + break + + if i == 0: + earliest_time = time.time() + batch_data.append(item.data) + batch_id.append(item.id) + + batch = _ItemWrapper(batch_data, batch_id) + _put_queue_item(output, batch, terminate_flag) + except Terminate: + return + + +class Unbatch(ConcurrentNode): + """ + Ungroups every batch (a list of items) into individual items and passes them to successive nodes. + """ + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__('thread', in_buffer_size, out_buffer_size) + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch = _get_queue_item(input, terminate_flag) + for id, data in zip(batch.id or itertools.repeat(None), batch.data): + item = _ItemWrapper(data, id) + _put_queue_item(output, item, terminate_flag) + except Terminate: + return + + +class Buffer(Node): + "A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time." + def __init__(self, size: int): + super().__init__(size, size) + self.size = size + self.input = self.output = Queue(maxsize=size) \ No newline at end of file diff --git a/moge/utils/tools.py b/moge/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37 --- /dev/null +++ b/moge/utils/tools.py @@ -0,0 +1,289 @@ +from typing import * +import time +from pathlib import Path +from numbers import Number +from functools import wraps +import warnings +import math +import json +import os +import importlib +import importlib.util + + +def catch_exception(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + import traceback + print(f"Exception in {fn.__name__}", end='r') + # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())}) + traceback.print_exc(chain=False) + time.sleep(0.1) + return None + return wrapper + + +class CallbackOnException: + def __init__(self, callback: Callable, exception: type): + self.exception = exception + self.callback = callback + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(exc_val, self.exception): + self.callback() + return True + return False + +def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: + for k, v in d.items(): + if isinstance(v, dict): + for sub_key in traverse_nested_dict_keys(v): + yield (k, ) + sub_key + else: + yield (k, ) + + +def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): + for k in keys: + d = d.get(k, default) + if d is None: + break + return d + +def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + + +def key_average(list_of_dicts: list) -> Dict[str, Any]: + """ + Returns a dictionary with the average value of each key in the input list of dictionaries. + """ + _nested_dict_keys = set() + for d in list_of_dicts: + _nested_dict_keys.update(traverse_nested_dict_keys(d)) + _nested_dict_keys = sorted(_nested_dict_keys) + result = {} + for k in _nested_dict_keys: + values = [] + for d in list_of_dicts: + v = get_nested_dict(d, k) + if v is not None and not math.isnan(v): + values.append(v) + avg = sum(values) / len(values) if values else float('nan') + set_nested_dict(result, k, avg) + return result + + +def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]: + """ + Flattens a nested dictionary into a single-level dictionary, with keys as tuples. + """ + items = [] + if parent_key is None: + parent_key = () + for k, v in d.items(): + new_key = parent_key + (k, ) + if isinstance(v, MutableMapping): + items.extend(flatten_nested_dict(v, new_key).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Unflattens a single-level dictionary into a nested dictionary, with keys as tuples. + """ + result = {} + for k, v in d.items(): + sub_dict = result + for k_ in k[:-1]: + if k_ not in sub_dict: + sub_dict[k_] = {} + sub_dict = sub_dict[k_] + sub_dict[k[-1]] = v + return result + + +def read_jsonl(file): + import json + with open(file, 'r') as f: + data = f.readlines() + return [json.loads(line) for line in data] + + +def write_jsonl(data: List[dict], file): + import json + with open(file, 'w') as f: + for item in data: + f.write(json.dumps(item) + '\n') + + +def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]): + import pandas as pd + data = [flatten_nested_dict(d) for d in data] + df = pd.DataFrame(data) + df = df.sort_index(axis=1) + df.columns = pd.MultiIndex.from_tuples(df.columns) + return df + + +def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]): + if isinstance(d, str): + for old, new in mapping.items(): + d = d.replace(old, new) + elif isinstance(d, list): + for i, item in enumerate(d): + d[i] = recursive_replace(item, mapping) + elif isinstance(d, dict): + for k, v in d.items(): + d[k] = recursive_replace(v, mapping) + return d + + +class timeit: + _history: Dict[str, List['timeit']] = {} + + def __init__(self, name: str = None, verbose: bool = True, average: bool = False): + self.name = name + self.verbose = verbose + self.start = None + self.end = None + self.average = average + if average and name not in timeit._history: + timeit._history[name] = [] + + def __call__(self, func: Callable): + import inspect + if inspect.iscoroutinefunction(func): + async def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = await func(*args, **kwargs) + return ret + return wrapper + else: + def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = func(*args, **kwargs) + return ret + return wrapper + + def __enter__(self): + self.start = time.time() + return self + + @property + def time(self) -> float: + assert self.start is not None, "Time not yet started." + assert self.end is not None, "Time not yet ended." + return self.end - self.start + + @property + def average_time(self) -> float: + assert self.average, "Average time not available." + return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name]) + + @property + def history(self) -> List['timeit']: + return timeit._history.get(self.name, []) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = time.time() + if self.average: + timeit._history[self.name].append(self) + if self.verbose: + if self.average: + avg = self.average_time + print(f"{self.name or 'It'} took {avg:.6f} seconds in average.") + else: + print(f"{self.name or 'It'} took {self.time:.6f} seconds.") + + +def strip_common_prefix_suffix(strings: List[str]) -> List[str]: + first = strings[0] + + for start in range(len(first)): + if any(s[start] != strings[0][start] for s in strings): + break + + for end in range(1, min(len(s) for s in strings)): + if any(s[-end] != first[-end] for s in strings): + break + + return [s[start:len(s) - end + 1] for s in strings] + + +def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): + from concurrent.futures import ThreadPoolExecutor + from contextlib import nullcontext + from tqdm import tqdm + + if pbar is not None: + pbar.total = len(inputs) if hasattr(inputs, '__len__') else None + else: + pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None) + + def decorator(fn: Callable): + with ( + ThreadPoolExecutor(max_workers=num_workers) as executor, + pbar + ): + pbar.refresh() + @catch_exception + @suppress_traceback + def _fn(input): + ret = fn(input) + pbar.update() + return ret + executor.map(_fn, inputs) + executor.shutdown(wait=True) + + return decorator + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + return wrapper + + +class no_warnings: + def __init__(self, action: str = 'ignore', **kwargs): + self.action = action + self.filter_kwargs = kwargs + + def __call__(self, fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter(self.action, **self.filter_kwargs) + return fn(*args, **kwargs) + return wrapper + + def __enter__(self): + self.warnings_manager = warnings.catch_warnings() + self.warnings_manager.__enter__() + warnings.simplefilter(self.action, **self.filter_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) + + +def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module \ No newline at end of file diff --git a/moge/utils/vis.py b/moge/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9c2378b58ec26ac5067b7ffcbd749a8ad968ce --- /dev/null +++ b/moge/utils/vis.py @@ -0,0 +1,65 @@ +from typing import * + +import numpy as np +import matplotlib + + +def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is None: + depth = np.where(depth > 0, depth, np.nan) + else: + depth = np.where((depth > 0) & mask, depth, np.nan) + disp = 1 / depth + if normalize: + min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99) + disp = (disp - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + depth = np.where(mask, depth, np.nan) + + min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999) + depth = (depth - min_depth) / (max_depth - min_depth) + colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + disparity = np.where(mask, disparity, np.nan) + + if normalize: + min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999) + disparity = (disparity - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray: + colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20)[..., :3] + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + if mask is not None: + normal = np.where(mask[..., None], normal, 0) + normal = normal * [0.5, -0.5, -0.5] + 0.5 + normal = (normal.clip(0, 1) * 255).astype(np.uint8) + return normal + + +def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None): + vmin, vmax = value_range if value_range is not None else (np.nanmin(error_map), np.nanmax(error_map)) + cmap = matplotlib.colormaps[cmap] + colorized_error_map = cmap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3] + if mask is not None: + colorized_error_map = np.where(mask[..., None], colorized_error_map, 0) + colorized_error_map = np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8)) + return colorized_error_map diff --git a/moge/utils/webfile.py b/moge/utils/webfile.py new file mode 100644 index 0000000000000000000000000000000000000000..1e98abf8413e1c9f408849b74f4d2025d25511b6 --- /dev/null +++ b/moge/utils/webfile.py @@ -0,0 +1,73 @@ +import requests +from typing import * + +__all__ = ["WebFile"] + + +class WebFile: + def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None): + self.url = url + self.session = session or requests.Session() + self.session.headers.update(headers or {}) + self._offset = 0 + self.size = size if size is not None else self._fetch_size() + + def _fetch_size(self): + with self.session.get(self.url, stream=True) as response: + response.raise_for_status() + content_length = response.headers.get("Content-Length") + if content_length is None: + raise ValueError("Missing Content-Length in header") + return int(content_length) + + def _fetch_data(self, offset: int, n: int) -> bytes: + headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"} + response = self.session.get(self.url, headers=headers) + response.raise_for_status() + return response.content + + def seekable(self) -> bool: + return True + + def tell(self) -> int: + return self._offset + + def available(self) -> int: + return self.size - self._offset + + def seek(self, offset: int, whence: int = 0) -> None: + if whence == 0: + new_offset = offset + elif whence == 1: + new_offset = self._offset + offset + elif whence == 2: + new_offset = self.size + offset + else: + raise ValueError("Invalid value for whence") + + self._offset = max(0, min(new_offset, self.size)) + + def read(self, n: Optional[int] = None) -> bytes: + if n is None or n < 0: + n = self.available() + else: + n = min(n, self.available()) + + if n == 0: + return b'' + + data = self._fetch_data(self._offset, n) + self._offset += len(data) + + return data + + def close(self) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + \ No newline at end of file diff --git a/moge/utils/webzipfile.py b/moge/utils/webzipfile.py new file mode 100644 index 0000000000000000000000000000000000000000..25ed1d3cd34720335eb001d77a278539ffef569b --- /dev/null +++ b/moge/utils/webzipfile.py @@ -0,0 +1,128 @@ +from typing import * +import io +import os +from zipfile import ( + ZipInfo, BadZipFile, ZipFile, ZipExtFile, + sizeFileHeader, structFileHeader, stringFileHeader, + _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS, + _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED +) +import struct +from requests import Session + +from .webfile import WebFile + + +class _SharedWebFile(WebFile): + def __init__(self, webfile: WebFile, pos: int): + super().__init__(webfile.url, webfile.session, size=webfile.size) + self.seek(pos) + + +class WebZipFile(ZipFile): + "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads." + def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None): + """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x', + or append 'a'.""" + webf = WebFile(url, session=session, headers=headers) + super().__init__(webf, mode='r') + + def open(self, name, mode="r", pwd=None, *, force_zip64=False): + """Return file-like object for 'name'. + + name is a string for the file name within the ZIP file, or a ZipInfo + object. + + mode should be 'r' to read a file already in the ZIP file, or 'w' to + write to a file newly added to the archive. + + pwd is the password to decrypt files (only used for reading). + + When writing, if the file size is not known in advance but may exceed + 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large + files. If the size is known in advance, it is best to pass a ZipInfo + instance for name, with zinfo.file_size set. + """ + if mode not in {"r", "w"}: + raise ValueError('open() requires mode "r" or "w"') + if pwd and (mode == "w"): + raise ValueError("pwd is only supported for reading files") + if not self.fp: + raise ValueError( + "Attempt to use ZIP archive that was already closed") + + assert mode == "r", "Only read mode is supported for now" + + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name + elif mode == 'w': + zinfo = ZipInfo(name) + zinfo.compress_type = self.compression + zinfo._compresslevel = self.compresslevel + else: + # Get info object for name + zinfo = self.getinfo(name) + + if mode == 'w': + return self._open_to_write(zinfo, force_zip64=force_zip64) + + if self._writing: + raise ValueError("Can't read from the ZIP file while there " + "is an open writing handle on it. " + "Close the writing handle before trying to read.") + + # Open for reading: + self._fileRefCnt += 1 + zef_file = _SharedWebFile(self.fp, zinfo.header_offset) + + try: + # Skip the file header: + fheader = zef_file.read(sizeFileHeader) + if len(fheader) != sizeFileHeader: + raise BadZipFile("Truncated file header") + fheader = struct.unpack(structFileHeader, fheader) + if fheader[_FH_SIGNATURE] != stringFileHeader: + raise BadZipFile("Bad magic number for file header") + + fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1) + + if zinfo.flag_bits & _MASK_COMPRESSED_PATCH: + # Zip 2.7: compressed patched data + raise NotImplementedError("compressed patched data (flag bit 5)") + + if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION: + # strong encryption + raise NotImplementedError("strong encryption (flag bit 6)") + + if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME: + # UTF-8 filename + fname_str = fname.decode("utf-8") + else: + fname_str = fname.decode(self.metadata_encoding or "cp437") + + if fname_str != zinfo.orig_filename: + raise BadZipFile( + 'File name in directory %r and header %r differ.' + % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED + if is_encrypted: + if not pwd: + pwd = self.pwd + if pwd and not isinstance(pwd, bytes): + raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__) + if not pwd: + raise RuntimeError("File %r is encrypted, password " + "required for extraction" % name) + else: + pwd = None + + return ZipExtFile(zef_file, mode, zinfo, pwd, True) + except: + zef_file.close() + raise \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..baaee7125785ddad4283601ae509c461032a8a15 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "moge" +version = "2.0.0" +description = "MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision" +readme = "README.md" +license = {text = "MIT"} +dependencies = [ + "click", + "opencv-python", + "scipy", + "matplotlib", + "trimesh", + "pillow", + "huggingface_hub", + "numpy", + "torch>=2.0.0", + "torchvision", + "gradio", + "utils3d @ git+https://github.com/EasternJournalist/utils3d.git@c5daf6f6c244d251f252102d09e9b7bcef791a38" +] +requires-python = ">=3.9" + +[project.urls] +Homepage = "https://github.com/microsoft/MoGe" + +[tool.setuptools.packages.find] +where = ["."] +include = ["moge*"] + +[project.scripts] +moge = "moge.scripts.cli:main" \ No newline at end of file diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..deb3aa62afbda00a7c7413b9eefa6f0ec18fb72b --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,10 @@ +{ + "include": [ + "moge", + "scripts", + "baselines" + ], + "ignore": [ + "**" + ] +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3dc26ac2a12fb1532417d19c56cc6e8bf8dd2946 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +# The versions are not specified since MoGe should be compatible with most versions of the packages. +# If incompatibilities are found, consider upgrading to latest versions or installing the following recommended version of the package. +torch # >= 2.0.0 +torchvision +gradio # ==2.8.13 +click # ==8.1.7 +opencv-python # ==4.10.0.84 +scipy # ==1.14.1 +matplotlib # ==3.9.2 +trimesh # ==4.5.1 +pillow # ==10.4.0 +huggingface_hub # ==0.25.2 +git+https://github.com/EasternJournalist/utils3d.git@c5daf6f6c244d251f252102d09e9b7bcef791a38 +