diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..a6839e09fb1af6a7c2937badd043cbceda3b2a73 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,90 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +# SCM syntax highlighting & preventing 3-way merges +pixi.lock merge=binary linguist-language=YAML linguist-generated=true +examples/kitchen/images/00.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/01.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/02.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/03.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/04.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/05.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/06.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/07.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/08.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/09.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/10.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/11.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/12.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/13.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/14.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/15.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/16.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/17.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/18.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/19.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/20.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/21.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/22.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/23.png filter=lfs diff=lfs merge=lfs -text +examples/kitchen/images/24.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/000.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/001.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/002.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/003.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/004.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/005.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/006.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/007.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/008.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/009.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/010.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/011.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/012.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/013.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/014.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/015.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/016.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/017.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/018.png filter=lfs diff=lfs merge=lfs -text +examples/llff_fern/images/019.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/000.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/001.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/002.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/003.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/004.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/005.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/006.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/007.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/008.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/009.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/010.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/011.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/012.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/013.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/014.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/015.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/016.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/017.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/018.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/019.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/020.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/021.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/022.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/023.png filter=lfs diff=lfs merge=lfs -text +examples/llff_flower/images/024.png filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_1.png filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_2.jpg filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_3.jpg filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_4.jpg filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_5.jpg filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_6.jpg filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_7.jpg filter=lfs diff=lfs merge=lfs -text +examples/room/images/no_overlap_8.jpg filter=lfs diff=lfs merge=lfs -text +examples/single_cartoon/images/model_was_never_trained_on_single_image_or_cartoon.jpg filter=lfs diff=lfs merge=lfs -text +examples/single_oil_painting/images/model_was_never_trained_on_single_image_or_oil_painting.png filter=lfs diff=lfs merge=lfs -text +examples/videos/Colosseum.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/fern.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/great_wall.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/kitchen.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/pyramid.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/room.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/single_cartoon.mp4 filter=lfs diff=lfs merge=lfs -text +examples/videos/single_oil_painting.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9d7352ead3cb8f5add2532c9e502dbf5a64d921d --- /dev/null +++ b/.gitignore @@ -0,0 +1,149 @@ +.hydra/ +output/ +ckpt/ +# Byte-compiled / optimized / DLL files +__pycache__/ +**/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Profiling data +.prof + +# Folder specific to your needs +**/tmp/ +**/outputs/skyseg.onnx +skyseg.onnx + +# pixi environments +.pixi +*.egg-info diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..3232ed665566ec047ce55a929db1581dbda266a1 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..72baaa2eb86da6050a43c1ea553c095932a5b939 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to vggt +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to vggt, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..e395ca3e2cdebf48a6375a3c1022d10caabba7db --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the β€œLicensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/README.md b/README.md index 1843148b7f5138662ccfd92a162a5dba25fbe38b..7b55d5b2be7f624bdcc3be8c94ae0803dbc375fb 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,288 @@ --- -title: Vggt -emoji: 🐒 -colorFrom: blue -colorTo: yellow +title: vggt +app_file: demo_gradio.py sdk: gradio sdk_version: 5.34.2 -app_file: app.py -pinned: false --- +
+

VGGT: Visual Geometry Grounded Transformer

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + + Paper PDF + +arXiv +Project Page + + + +**[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**; **[Meta AI](https://ai.facebook.com/research/)** + + +[Jianyuan Wang](https://jytime.github.io/), [Minghao Chen](https://silent-chen.github.io/), [Nikita Karaev](https://nikitakaraevv.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/), [David Novotny](https://d-novotny.github.io/) +
+ +```bibtex +@inproceedings{wang2025vggt, + title={VGGT: Visual Geometry Grounded Transformer}, + author={Wang, Jianyuan and Chen, Minghao and Karaev, Nikita and Vedaldi, Andrea and Rupprecht, Christian and Novotny, David}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2025} +} +``` + +## Updates +- [June 13, 2025] Honored to receive the Best Paper Award at CVPR 2025! Apologies if I’m slow to respond to queries or GitHub issues these days. If you’re interested, our oral presentation is available [here](https://docs.google.com/presentation/d/1JVuPnuZx6RgAy-U5Ezobg73XpBi7FrOh/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true). (Note: it’s shared in .pptx format with animations β€” quite large, but feel free to use it as a template if helpful.) + + +- [June 2, 2025] Added a script to run VGGT and save predictions in COLMAP format, with bundle adjustment support optional. The saved COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) or other NeRF/Gaussian splatting libraries. + + +- [May 3, 2025] Evaluation code for reproducing our camera pose estimation results on Co3D is now available in the [evaluation](https://github.com/facebookresearch/vggt/tree/evaluation) branch. + + +- [Apr 13, 2025] Training code is being gradually cleaned and uploaded to the [training](https://github.com/facebookresearch/vggt/tree/training) branch. It will be merged into the main branch once finalized. + +## Overview + +Visual Geometry Grounded Transformer (VGGT, CVPR 2025) is a feed-forward neural network that directly infers all key 3D attributes of a scene, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks, **from one, a few, or hundreds of its views, within seconds**. + + +## Quick Start + +First, clone this repository to your local machine, and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub). + +```bash +git clone git@github.com:facebookresearch/vggt.git +cd vggt +pip install -r requirements.txt +``` + +Alternatively, you can install VGGT as a package (click here for details). + + +Now, try the model with just a few lines of code: + +```python +import torch +from vggt.models.vggt import VGGT +from vggt.utils.load_fn import load_and_preprocess_images + +device = "cuda" if torch.cuda.is_available() else "cpu" +# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) +dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + +# Initialize the model and load the pretrained weights. +# This will automatically download the model weights the first time it's run, which may take a while. +model = VGGT.from_pretrained("facebook/VGGT-1B").to(device) + +# Load and preprocess example images (replace with your own image paths) +image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"] +images = load_and_preprocess_images(image_names).to(device) + +with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + # Predict attributes including cameras, depth maps, and point maps. + predictions = model(images) +``` + +The model weights will be automatically downloaded from Hugging Face. If you encounter issues such as slow loading, you can manually download them [here](https://huggingface.co/facebook/VGGT-1B/blob/main/model.pt) and load, or: + +```python +model = VGGT() +_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" +model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) +``` + +## Detailed Usage + +
+Click to expand + +You can also optionally choose which attributes (branches) to predict, as shown below. This achieves the same result as the example above. This example uses a batch size of 1 (processing a single scene), but it naturally works for multiple scenes. + +```python +from vggt.utils.pose_enc import pose_encoding_to_extri_intri +from vggt.utils.geometry import unproject_depth_map_to_point_map + +with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + images = images[None] # add batch dimension + aggregated_tokens_list, ps_idx = model.aggregator(images) + + # Predict Cameras + pose_enc = model.camera_head(aggregated_tokens_list)[-1] + # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world) + extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:]) + + # Predict Depth Maps + depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx) + + # Predict Point Maps + point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx) + + # Construct 3D Points from Depth Maps and Cameras + # which usually leads to more accurate 3D points than point map branch + point_map_by_unprojection = unproject_depth_map_to_point_map(depth_map.squeeze(0), + extrinsic.squeeze(0), + intrinsic.squeeze(0)) + + # Predict Tracks + # choose your own points to track, with shape (N, 2) for one scene + query_points = torch.FloatTensor([[100.0, 200.0], + [60.72, 259.94]]).to(device) + track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None]) +``` + + +Furthermore, if certain pixels in the input frames are unwanted (e.g., reflective surfaces, sky, or water), you can simply mask them by setting the corresponding pixel values to 0 or 1. Precise segmentation masks aren't necessary - simple bounding box masks work effectively (check this [issue](https://github.com/facebookresearch/vggt/issues/47) for an example). + +
+ + +## Interactive Demo + +We provide multiple ways to visualize your 3D reconstructions. Before using these visualization tools, install the required dependencies: + +```bash +pip install -r requirements_demo.txt +``` + +### Interactive 3D Visualization + +**Please note:** VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, independent of VGGT's processing time. The visualization is slow especially when the number of images is large. + + +#### Gradio Web Interface + +Our Gradio-based interface allows you to upload images/videos, run reconstruction, and interactively explore the 3D scene in your browser. You can launch this in your local machine or try it on [Hugging Face](https://huggingface.co/spaces/facebook/vggt). + + +```bash +python demo_gradio.py +``` + +
+Click to preview the Gradio interactive interface + +![Gradio Web Interface Preview](https://jytime.github.io/data/vggt_hf_demo_screen.png) +
+ + +#### Viser 3D Viewer + +Run the following command to run reconstruction and visualize the point clouds in viser. Note this script requires a path to a folder containing images. It assumes only image files under the folder. You can set `--use_point_map` to use the point cloud from the point map branch, instead of the depth-based point cloud. + +```bash +python demo_viser.py --image_folder path/to/your/images/folder +``` + +## Exporting to COLMAP Format + +We also support exporting VGGT's predictions directly to COLMAP format, by: + +```bash +# Feedforward prediction only +python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ + +# With bundle adjustment +python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba + +# Run with bundle adjustment using reduced parameters for faster processing +# Reduces max_query_pts from 4096 (default) to 2048 and query_frame_num from 8 (default) to 5 +# Trade-off: Faster execution but potentially less robust reconstruction in complex scenes (you may consider setting query_frame_num equal to your total number of images) +# See demo_colmap.py for additional bundle adjustment configuration options +python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba --max_query_pts=2048 --query_frame_num=5 +``` + +Please ensure that the images are stored in `/YOUR/SCENE_DIR/images/`. This folder should contain only the images. Check the examples folder for the desired data structure. + +The reconstruction result (camera parameters and 3D points) will be automatically saved under `/YOUR/SCENE_DIR/sparse/` in the COLMAP format, such as: + +``` +SCENE_DIR/ +β”œβ”€β”€ images/ +└── sparse/ + β”œβ”€β”€ cameras.bin + β”œβ”€β”€ images.bin + └── points3D.bin +``` + +## Integration with Gaussian Splatting + + +The exported COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) for Gaussian Splatting training. Install `gsplat` following their official instructions (we recommend `gsplat==1.3.0`): + +An example command to train the model is: +``` +cd gsplat +python examples/simple_trainer.py default --data_factor 1 --data_dir /YOUR/SCENE_DIR/ --result_dir /YOUR/RESULT_DIR/ +``` + + + +## Zero-shot Single-view Reconstruction + +Our model shows surprisingly good performance on single-view reconstruction, although it was never trained for this task. The model does not need to duplicate the single-view image to a pair, instead, it can directly infer the 3D structure from the tokens of the single view image. Feel free to try it with our demos above, which naturally works for single-view reconstruction. + + +We did not quantitatively test monocular depth estimation performance ourselves, but [@kabouzeid](https://github.com/kabouzeid) generously provided a comparison of VGGT to recent methods [here](https://github.com/facebookresearch/vggt/issues/36). VGGT shows competitive or better results compared to state-of-the-art monocular approaches such as DepthAnything v2 or MoGe, despite never being explicitly trained for single-view tasks. + + + +## Runtime and GPU Memory + +We benchmark the runtime and GPU memory usage of VGGT's aggregator on a single NVIDIA H100 GPU across various input sizes. + +| **Input Frames** | 1 | 2 | 4 | 8 | 10 | 20 | 50 | 100 | 200 | +|:----------------:|:-:|:-:|:-:|:-:|:--:|:--:|:--:|:---:|:---:| +| **Time (s)** | 0.04 | 0.05 | 0.07 | 0.11 | 0.14 | 0.31 | 1.04 | 3.12 | 8.75 | +| **Memory (GB)** | 1.88 | 2.07 | 2.45 | 3.23 | 3.63 | 5.58 | 11.41 | 21.15 | 40.63 | + +Note that these results were obtained using Flash Attention 3, which is faster than the default Flash Attention 2 implementation while maintaining almost the same memory usage. Feel free to compile Flash Attention 3 from source to get better performance. + + +## Research Progression + +Our work builds upon a series of previous research projects. If you're interested in understanding how our research evolved, check out our previous works: + + + + + + + + + + + + + + + + + + +
+ Deep SfM Revisited + ──┐
+ PoseDiffusion + ─────► + VGGSfM ──► + VGGT +
+ CoTracker + β”€β”€β”˜
+ + +## Acknowledgements + +Thanks to these great repositories: [PoseDiffusion](https://github.com/facebookresearch/PoseDiffusion), [VGGSfM](https://github.com/facebookresearch/vggsfm), [CoTracker](https://github.com/facebookresearch/co-tracker), [DINOv2](https://github.com/facebookresearch/dinov2), [Dust3r](https://github.com/naver/dust3r), [Moge](https://github.com/microsoft/moge), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [Sky Segmentation](https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), [Metric3D](https://github.com/YvanYin/Metric3D) and many other inspiring works in the community. + +## Checklist + +- [ ] Release the training code +- [ ] Release VGGT-500M and VGGT-200M + + +## License +See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available. diff --git a/demo_colmap.py b/demo_colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..836af1722844ba637a3a0a23a94a54b3f5e4550f --- /dev/null +++ b/demo_colmap.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +import numpy as np +import glob +import os +import copy +import torch +import torch.nn.functional as F + +# Configure CUDA settings +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.deterministic = False + +import argparse +from pathlib import Path +import trimesh +import pycolmap + + +from vggt.models.vggt import VGGT +from vggt.utils.load_fn import load_and_preprocess_images_square +from vggt.utils.pose_enc import pose_encoding_to_extri_intri +from vggt.utils.geometry import unproject_depth_map_to_point_map +from vggt.utils.helper import create_pixel_coordinate_grid, randomly_limit_trues +from vggt.dependency.track_predict import predict_tracks +from vggt.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track + + +# TODO: add support for masks +# TODO: add iterative BA +# TODO: add support for radial distortion, which needs extra_params +# TODO: test with more cases +# TODO: test different camera types + + +def parse_args(): + parser = argparse.ArgumentParser(description="VGGT Demo") + parser.add_argument("--scene_dir", type=str, required=True, help="Directory containing the scene images") + parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") + parser.add_argument("--use_ba", action="store_true", default=False, help="Use BA for reconstruction") + ######### BA parameters ######### + parser.add_argument( + "--max_reproj_error", type=float, default=8.0, help="Maximum reprojection error for reconstruction" + ) + parser.add_argument("--shared_camera", action="store_true", default=False, help="Use shared camera for all images") + parser.add_argument("--camera_type", type=str, default="SIMPLE_PINHOLE", help="Camera type for reconstruction") + parser.add_argument("--vis_thresh", type=float, default=0.2, help="Visibility threshold for tracks") + parser.add_argument("--query_frame_num", type=int, default=8, help="Number of frames to query") + parser.add_argument("--max_query_pts", type=int, default=4096, help="Maximum number of query points") + parser.add_argument( + "--fine_tracking", action="store_true", default=True, help="Use fine tracking (slower but more accurate)" + ) + parser.add_argument( + "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering (wo BA)" + ) + return parser.parse_args() + + +def run_VGGT(model, images, dtype, resolution=518): + # images: [B, 3, H, W] + + assert len(images.shape) == 4 + assert images.shape[1] == 3 + + # hard-coded to use 518 for VGGT + images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False) + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + images = images[None] # add batch dimension + aggregated_tokens_list, ps_idx = model.aggregator(images) + + # Predict Cameras + pose_enc = model.camera_head(aggregated_tokens_list)[-1] + # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world) + extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:]) + # Predict Depth Maps + depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx) + + extrinsic = extrinsic.squeeze(0).cpu().numpy() + intrinsic = intrinsic.squeeze(0).cpu().numpy() + depth_map = depth_map.squeeze(0).cpu().numpy() + depth_conf = depth_conf.squeeze(0).cpu().numpy() + return extrinsic, intrinsic, depth_map, depth_conf + + +def demo_fn(args): + # Print configuration + print("Arguments:", vars(args)) + + # Set seed for reproducibility + np.random.seed(args.seed) + torch.manual_seed(args.seed) + random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) # for multi-GPU + print(f"Setting seed as: {args.seed}") + + # Set device and dtype + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + print(f"Using dtype: {dtype}") + + # Run VGGT for camera and depth estimation + model = VGGT() + _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" + model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) + model.eval() + model = model.to(device) + print(f"Model loaded") + + # Get image paths and preprocess them + image_dir = os.path.join(args.scene_dir, "images") + image_path_list = glob.glob(os.path.join(image_dir, "*")) + if len(image_path_list) == 0: + raise ValueError(f"No images found in {image_dir}") + base_image_path_list = [os.path.basename(path) for path in image_path_list] + + # Load images and original coordinates + # Load Image in 1024, while running VGGT with 518 + vggt_fixed_resolution = 518 + img_load_resolution = 1024 + + images, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution) + images = images.to(device) + original_coords = original_coords.to(device) + print(f"Loaded {len(images)} images from {image_dir}") + + # Run VGGT to estimate camera and depth + # Run with 518x518 images + extrinsic, intrinsic, depth_map, depth_conf = run_VGGT(model, images, dtype, vggt_fixed_resolution) + points_3d = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic) + + if args.use_ba: + image_size = np.array(images.shape[-2:]) + scale = img_load_resolution / vggt_fixed_resolution + shared_camera = args.shared_camera + + with torch.cuda.amp.autocast(dtype=dtype): + # Predicting Tracks + # Using VGGSfM tracker instead of VGGT tracker for efficiency + # VGGT tracker requires multiple backbone runs to query different frames (this is a problem caused by the training process) + # Will be fixed in VGGT v2 + + # You can also change the pred_tracks to tracks from any other methods + # e.g., from COLMAP, from CoTracker, or by chaining 2D matches from Lightglue/LoFTR. + pred_tracks, pred_vis_scores, pred_confs, points_3d, points_rgb = predict_tracks( + images, + conf=depth_conf, + points_3d=points_3d, + masks=None, + max_query_pts=args.max_query_pts, + query_frame_num=args.query_frame_num, + keypoint_extractor="aliked+sp", + fine_tracking=args.fine_tracking, + ) + + torch.cuda.empty_cache() + + # rescale the intrinsic matrix from 518 to 1024 + intrinsic[:, :2, :] *= scale + track_mask = pred_vis_scores > args.vis_thresh + + # TODO: radial distortion, iterative BA, masks + reconstruction, valid_track_mask = batch_np_matrix_to_pycolmap( + points_3d, + extrinsic, + intrinsic, + pred_tracks, + image_size, + masks=track_mask, + max_reproj_error=args.max_reproj_error, + shared_camera=shared_camera, + camera_type=args.camera_type, + points_rgb=points_rgb, + ) + + if reconstruction is None: + raise ValueError("No reconstruction can be built with BA") + + # Bundle Adjustment + ba_options = pycolmap.BundleAdjustmentOptions() + pycolmap.bundle_adjustment(reconstruction, ba_options) + + reconstruction_resolution = img_load_resolution + else: + conf_thres_value = args.conf_thres_value + max_points_for_colmap = 100000 # randomly sample 3D points + shared_camera = False # in the feedforward manner, we do not support shared camera + camera_type = "PINHOLE" # in the feedforward manner, we only support PINHOLE camera + + image_size = np.array([vggt_fixed_resolution, vggt_fixed_resolution]) + num_frames, height, width, _ = points_3d.shape + + points_rgb = F.interpolate( + images, size=(vggt_fixed_resolution, vggt_fixed_resolution), mode="bilinear", align_corners=False + ) + points_rgb = (points_rgb.cpu().numpy() * 255).astype(np.uint8) + points_rgb = points_rgb.transpose(0, 2, 3, 1) + + # (S, H, W, 3), with x, y coordinates and frame indices + points_xyf = create_pixel_coordinate_grid(num_frames, height, width) + + conf_mask = depth_conf >= conf_thres_value + # at most writing 100000 3d points to colmap reconstruction object + conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap) + + points_3d = points_3d[conf_mask] + points_xyf = points_xyf[conf_mask] + points_rgb = points_rgb[conf_mask] + + print("Converting to COLMAP format") + reconstruction = batch_np_matrix_to_pycolmap_wo_track( + points_3d, + points_xyf, + points_rgb, + extrinsic, + intrinsic, + image_size, + shared_camera=shared_camera, + camera_type=camera_type, + ) + + reconstruction_resolution = vggt_fixed_resolution + + reconstruction = rename_colmap_recons_and_rescale_camera( + reconstruction, + base_image_path_list, + original_coords.cpu().numpy(), + img_size=reconstruction_resolution, + shift_point2d_to_original_res=True, + shared_camera=shared_camera, + ) + + print(f"Saving reconstruction to {args.scene_dir}/sparse") + sparse_reconstruction_dir = os.path.join(args.scene_dir, "sparse") + os.makedirs(sparse_reconstruction_dir, exist_ok=True) + reconstruction.write(sparse_reconstruction_dir) + + # Save point cloud for fast visualization + trimesh.PointCloud(points_3d, colors=points_rgb).export(os.path.join(args.scene_dir, "sparse/points.ply")) + + return True + + +def rename_colmap_recons_and_rescale_camera( + reconstruction, image_paths, original_coords, img_size, shift_point2d_to_original_res=False, shared_camera=False +): + rescale_camera = True + + for pyimageid in reconstruction.images: + # Reshaped the padded&resized image to the original size + # Rename the images to the original names + pyimage = reconstruction.images[pyimageid] + pycamera = reconstruction.cameras[pyimage.camera_id] + pyimage.name = image_paths[pyimageid - 1] + + if rescale_camera: + # Rescale the camera parameters + pred_params = copy.deepcopy(pycamera.params) + + real_image_size = original_coords[pyimageid - 1, -2:] + resize_ratio = max(real_image_size) / img_size + pred_params = pred_params * resize_ratio + real_pp = real_image_size / 2 + pred_params[-2:] = real_pp # center of the image + + pycamera.params = pred_params + pycamera.width = real_image_size[0] + pycamera.height = real_image_size[1] + + if shift_point2d_to_original_res: + # Also shift the point2D to original resolution + top_left = original_coords[pyimageid - 1, :2] + + for point2D in pyimage.points2D: + point2D.xy = (point2D.xy - top_left) * resize_ratio + + if shared_camera: + # If shared_camera, all images share the same camera + # no need to rescale any more + rescale_camera = False + + return reconstruction + + +if __name__ == "__main__": + args = parse_args() + with torch.no_grad(): + demo_fn(args) + + +# Work in Progress (WIP) + +""" +VGGT Runner Script +================= + +A script to run the VGGT model for 3D reconstruction from image sequences. + +Directory Structure +------------------ +Input: + input_folder/ + └── images/ # Source images for reconstruction + +Output: + output_folder/ + β”œβ”€β”€ images/ + β”œβ”€β”€ sparse/ # Reconstruction results + β”‚ β”œβ”€β”€ cameras.bin # Camera parameters (COLMAP format) + β”‚ β”œβ”€β”€ images.bin # Pose for each image (COLMAP format) + β”‚ β”œβ”€β”€ points3D.bin # 3D points (COLMAP format) + β”‚ └── points.ply # Point cloud visualization file + └── visuals/ # Visualization outputs TODO + +Key Features +----------- +β€’ Dual-mode Support: Run reconstructions using either VGGT or VGGT+BA +β€’ Resolution Preservation: Maintains original image resolution in camera parameters and tracks +β€’ COLMAP Compatibility: Exports results in standard COLMAP sparse reconstruction format +""" diff --git a/demo_gradio.py b/demo_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..466a5ff48a3976a1ec8566e3afa2499910fc5aa6 --- /dev/null +++ b/demo_gradio.py @@ -0,0 +1,690 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import cv2 +import torch +import numpy as np +import gradio as gr +import sys +import shutil +from datetime import datetime +import glob +import gc +import time + +sys.path.append("vggt/") + +from visual_util import predictions_to_glb +from vggt.models.vggt import VGGT +from vggt.utils.load_fn import load_and_preprocess_images +from vggt.utils.pose_enc import pose_encoding_to_extri_intri +from vggt.utils.geometry import unproject_depth_map_to_point_map + +device = "cuda" if torch.cuda.is_available() else "cpu" + +print("Initializing and loading VGGT model...") +# model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model + +model = VGGT() +_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" +model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) + + +model.eval() +model = model.to(device) + + +# ------------------------------------------------------------------------- +# 1) Core model inference +# ------------------------------------------------------------------------- +def run_model(target_dir, model) -> dict: + """ + Run the VGGT model on images in the 'target_dir/images' folder and return predictions. + """ + print(f"Processing images from {target_dir}") + + # Device check + device = "cuda" if torch.cuda.is_available() else "cpu" + if not torch.cuda.is_available(): + raise ValueError("CUDA is not available. Check your environment.") + + # Move model to device + model = model.to(device) + model.eval() + + # Load and preprocess images + image_names = glob.glob(os.path.join(target_dir, "images", "*")) + image_names = sorted(image_names) + print(f"Found {len(image_names)} images") + if len(image_names) == 0: + raise ValueError("No images found. Check your upload.") + + images = load_and_preprocess_images(image_names).to(device) + print(f"Preprocessed images shape: {images.shape}") + + # Run inference + print("Running inference...") + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + predictions = model(images) + + # Convert pose encoding to extrinsic and intrinsic matrices + print("Converting pose encoding to extrinsic and intrinsic matrices...") + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) + predictions["extrinsic"] = extrinsic + predictions["intrinsic"] = intrinsic + + # Convert tensors to numpy + for key in predictions.keys(): + if isinstance(predictions[key], torch.Tensor): + predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension + + # Generate world points from depth map + print("Computing world points from depth map...") + depth_map = predictions["depth"] # (S, H, W, 1) + world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"]) + predictions["world_points_from_depth"] = world_points + + # Clean up + torch.cuda.empty_cache() + return predictions + + +# ------------------------------------------------------------------------- +# 2) Handle uploaded video/images --> produce target_dir + images +# ------------------------------------------------------------------------- +def handle_uploads(input_video, input_images): + """ + Create a new 'target_dir' + 'images' subfolder, and place user-uploaded + images or extracted frames from video into it. Return (target_dir, image_paths). + """ + start_time = time.time() + gc.collect() + torch.cuda.empty_cache() + + # Create a unique folder name + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + target_dir = f"input_images_{timestamp}" + target_dir_images = os.path.join(target_dir, "images") + + # Clean up if somehow that folder already exists + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir) + os.makedirs(target_dir_images) + + image_paths = [] + + # --- Handle images --- + if input_images is not None: + for file_data in input_images: + if isinstance(file_data, dict) and "name" in file_data: + file_path = file_data["name"] + else: + file_path = file_data + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + shutil.copy(file_path, dst_path) + image_paths.append(dst_path) + + # --- Handle video --- + if input_video is not None: + if isinstance(input_video, dict) and "name" in input_video: + video_path = input_video["name"] + else: + video_path = input_video + + vs = cv2.VideoCapture(video_path) + fps = vs.get(cv2.CAP_PROP_FPS) + frame_interval = int(fps * 1) # 1 frame/sec + + count = 0 + video_frame_num = 0 + while True: + gotit, frame = vs.read() + if not gotit: + break + count += 1 + if count % frame_interval == 0: + image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png") + cv2.imwrite(image_path, frame) + image_paths.append(image_path) + video_frame_num += 1 + + # Sort final images for gallery + image_paths = sorted(image_paths) + + end_time = time.time() + print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds") + return target_dir, image_paths + + +# ------------------------------------------------------------------------- +# 3) Update gallery on upload +# ------------------------------------------------------------------------- +def update_gallery_on_upload(input_video, input_images): + """ + Whenever user uploads or changes files, immediately handle them + and show in the gallery. Return (target_dir, image_paths). + If nothing is uploaded, returns "None" and empty list. + """ + if not input_video and not input_images: + return None, None, None, None + target_dir, image_paths = handle_uploads(input_video, input_images) + return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing." + + +# ------------------------------------------------------------------------- +# 4) Reconstruction: uses the target_dir plus any viz parameters +# ------------------------------------------------------------------------- +def gradio_demo( + target_dir, + conf_thres=3.0, + frame_filter="All", + mask_black_bg=False, + mask_white_bg=False, + show_cam=True, + mask_sky=False, + prediction_mode="Pointmap Regression", +): + """ + Perform reconstruction using the already-created target_dir/images. + """ + if not os.path.isdir(target_dir) or target_dir == "None": + return None, "No valid target directory found. Please upload first.", None, None + + start_time = time.time() + gc.collect() + torch.cuda.empty_cache() + + # Prepare frame_filter dropdown + target_dir_images = os.path.join(target_dir, "images") + all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] + all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] + frame_filter_choices = ["All"] + all_files + + print("Running run_model...") + with torch.no_grad(): + predictions = run_model(target_dir, model) + + # Save predictions + prediction_save_path = os.path.join(target_dir, "predictions.npz") + np.savez(prediction_save_path, **predictions) + + # Handle None frame_filter + if frame_filter is None: + frame_filter = "All" + + # Build a GLB file name + glbfile = os.path.join( + target_dir, + f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb", + ) + + # Convert predictions to GLB + glbscene = predictions_to_glb( + predictions, + conf_thres=conf_thres, + filter_by_frames=frame_filter, + mask_black_bg=mask_black_bg, + mask_white_bg=mask_white_bg, + show_cam=show_cam, + mask_sky=mask_sky, + target_dir=target_dir, + prediction_mode=prediction_mode, + ) + glbscene.export(file_obj=glbfile) + + # Cleanup + del predictions + gc.collect() + torch.cuda.empty_cache() + + end_time = time.time() + print(f"Total time: {end_time - start_time:.2f} seconds (including IO)") + log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." + + return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True) + + +# ------------------------------------------------------------------------- +# 5) Helper functions for UI resets + re-visualization +# ------------------------------------------------------------------------- +def clear_fields(): + """ + Clears the 3D viewer, the stored target_dir, and empties the gallery. + """ + return None + + +def update_log(): + """ + Display a quick log message while waiting. + """ + return "Loading and Reconstructing..." + + +def update_visualization( + target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example +): + """ + Reload saved predictions from npz, create (or reuse) the GLB for new parameters, + and return it for the 3D viewer. If is_example == "True", skip. + """ + + # If it's an example click, skip as requested + if is_example == "True": + return None, "No reconstruction available. Please click the Reconstruct button first." + + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return None, "No reconstruction available. Please click the Reconstruct button first." + + predictions_path = os.path.join(target_dir, "predictions.npz") + if not os.path.exists(predictions_path): + return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first." + + key_list = [ + "pose_enc", + "depth", + "depth_conf", + "world_points", + "world_points_conf", + "images", + "extrinsic", + "intrinsic", + "world_points_from_depth", + ] + + loaded = np.load(predictions_path) + predictions = {key: np.array(loaded[key]) for key in key_list} + + glbfile = os.path.join( + target_dir, + f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb", + ) + + if not os.path.exists(glbfile): + glbscene = predictions_to_glb( + predictions, + conf_thres=conf_thres, + filter_by_frames=frame_filter, + mask_black_bg=mask_black_bg, + mask_white_bg=mask_white_bg, + show_cam=show_cam, + mask_sky=mask_sky, + target_dir=target_dir, + prediction_mode=prediction_mode, + ) + glbscene.export(file_obj=glbfile) + + return glbfile, "Updating Visualization" + + +# ------------------------------------------------------------------------- +# Example images +# ------------------------------------------------------------------------- + +great_wall_video = "examples/videos/great_wall.mp4" +colosseum_video = "examples/videos/Colosseum.mp4" +room_video = "examples/videos/room.mp4" +kitchen_video = "examples/videos/kitchen.mp4" +fern_video = "examples/videos/fern.mp4" +single_cartoon_video = "examples/videos/single_cartoon.mp4" +single_oil_painting_video = "examples/videos/single_oil_painting.mp4" +pyramid_video = "examples/videos/pyramid.mp4" + + +# ------------------------------------------------------------------------- +# 6) Build Gradio UI +# ------------------------------------------------------------------------- +theme = gr.themes.Ocean() +theme.set( + checkbox_label_background_fill_selected="*button_primary_background_fill", + checkbox_label_text_color_selected="*button_primary_text_color", +) + +with gr.Blocks( + theme=theme, + css=""" + .custom-log * { + font-style: italic; + font-size: 22px !important; + background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); + -webkit-background-clip: text; + background-clip: text; + font-weight: bold !important; + color: transparent !important; + text-align: center !important; + } + + .example-log * { + font-style: italic; + font-size: 16px !important; + background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + } + + #my_radio .wrap { + display: flex; + flex-wrap: nowrap; + justify-content: center; + align-items: center; + } + + #my_radio .wrap label { + display: flex; + width: 50%; + justify-content: center; + align-items: center; + margin: 0; + padding: 10px 0; + box-sizing: border-box; + } + """, +) as demo: + # Instead of gr.State, we use a hidden Textbox: + is_example = gr.Textbox(label="is_example", visible=False, value="None") + num_images = gr.Textbox(label="num_images", visible=False, value="None") + + gr.HTML( + """ +

πŸ›οΈ VGGT: Visual Geometry Grounded Transformer

+

+ πŸ™ GitHub Repository | + Project Page +

+ +
+

Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.

+ +

Getting Started:

+
    +
  1. Upload Your Data: Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).
  2. +
  3. Preview: Your uploaded images will appear in the gallery on the left.
  4. +
  5. Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction process.
  6. +
  7. Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.
  8. +
  9. + Adjust Visualization (Optional): + After reconstruction, you can fine-tune the visualization using the options below +
    + (click to expand): +
      +
    • Confidence Threshold: Adjust the filtering of points based on confidence.
    • +
    • Show Points from Frame: Select specific frames to display in the point cloud.
    • +
    • Show Camera: Toggle the display of estimated camera positions.
    • +
    • Filter Sky / Filter Black Background: Remove sky or black-background points.
    • +
    • Select a Prediction Mode: Choose between "Depthmap and Camera Branch" or "Pointmap Branch."
    • +
    +
    +
  10. +
+

Please note: VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time.

+
+ """ + ) + + target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") + + with gr.Row(): + with gr.Column(scale=2): + input_video = gr.Video(label="Upload Video", interactive=True) + input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True) + + image_gallery = gr.Gallery( + label="Preview", + columns=4, + height="300px", + show_download_button=True, + object_fit="contain", + preview=True, + ) + + with gr.Column(scale=4): + with gr.Column(): + gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**") + log_output = gr.Markdown( + "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"] + ) + reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5) + + with gr.Row(): + submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") + clear_btn = gr.ClearButton( + [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery], + scale=1, + ) + + with gr.Row(): + prediction_mode = gr.Radio( + ["Depthmap and Camera Branch", "Pointmap Branch"], + label="Select a Prediction Mode", + value="Depthmap and Camera Branch", + scale=1, + elem_id="my_radio", + ) + + with gr.Row(): + conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)") + frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame") + with gr.Column(): + show_cam = gr.Checkbox(label="Show Camera", value=True) + mask_sky = gr.Checkbox(label="Filter Sky", value=False) + mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False) + mask_white_bg = gr.Checkbox(label="Filter White Background", value=False) + + # ---------------------- Examples section ---------------------- + examples = [ + [colosseum_video, "22", None, 20.0, False, False, True, False, "Depthmap and Camera Branch", "True"], + [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"], + [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"], + [single_oil_painting_video, "1", None, 20.0, False, False, True, True, "Depthmap and Camera Branch", "True"], + [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"], + [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"], + [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"], + ] + + def example_pipeline( + input_video, + num_images_str, + input_images, + conf_thres, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example_str, + ): + """ + 1) Copy example images to new target_dir + 2) Reconstruct + 3) Return model3D + logs + new_dir + updated dropdown + gallery + We do NOT return is_example. It's just an input. + """ + target_dir, image_paths = handle_uploads(input_video, input_images) + # Always use "All" for frame_filter in examples + frame_filter = "All" + glbfile, log_msg, dropdown = gradio_demo( + target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode + ) + return glbfile, log_msg, target_dir, dropdown, image_paths + + gr.Markdown("Click any row to load an example.", elem_classes=["example-log"]) + + gr.Examples( + examples=examples, + inputs=[ + input_video, + num_images, + input_images, + conf_thres, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery], + fn=example_pipeline, + cache_examples=False, + examples_per_page=50, + ) + + # ------------------------------------------------------------------------- + # "Reconstruct" button logic: + # - Clear fields + # - Update log + # - gradio_demo(...) with the existing target_dir + # - Then set is_example = "False" + # ------------------------------------------------------------------------- + submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( + fn=update_log, inputs=[], outputs=[log_output] + ).then( + fn=gradio_demo, + inputs=[ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + ], + outputs=[reconstruction_output, log_output, frame_filter], + ).then( + fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False" + ) + + # ------------------------------------------------------------------------- + # Real-time Visualization Updates + # ------------------------------------------------------------------------- + conf_thres.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + frame_filter.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + mask_black_bg.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + mask_white_bg.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + show_cam.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + mask_sky.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + prediction_mode.change( + update_visualization, + [ + target_dir_output, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, + is_example, + ], + [reconstruction_output, log_output], + ) + + # ------------------------------------------------------------------------- + # Auto-update gallery whenever user uploads or changes their files + # ------------------------------------------------------------------------- + input_video.change( + fn=update_gallery_on_upload, + inputs=[input_video, input_images], + outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], + ) + input_images.change( + fn=update_gallery_on_upload, + inputs=[input_video, input_images], + outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], + ) + + demo.queue(max_size=20).launch(show_error=True, share=True) diff --git a/demo_viser.py b/demo_viser.py new file mode 100644 index 0000000000000000000000000000000000000000..e0211dac21d980c1b048ae15288a28e128f1f342 --- /dev/null +++ b/demo_viser.py @@ -0,0 +1,402 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import glob +import time +import threading +import argparse +from typing import List, Optional + +import numpy as np +import torch +from tqdm.auto import tqdm +import viser +import viser.transforms as viser_tf +import cv2 + + +try: + import onnxruntime +except ImportError: + print("onnxruntime not found. Sky segmentation may not work.") + +from visual_util import segment_sky, download_file_from_url +from vggt.models.vggt import VGGT +from vggt.utils.load_fn import load_and_preprocess_images +from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map +from vggt.utils.pose_enc import pose_encoding_to_extri_intri + + +def viser_wrapper( + pred_dict: dict, + port: int = 8080, + init_conf_threshold: float = 50.0, # represents percentage (e.g., 50 means filter lowest 50%) + use_point_map: bool = False, + background_mode: bool = False, + mask_sky: bool = False, + image_folder: str = None, +): + """ + Visualize predicted 3D points and camera poses with viser. + + Args: + pred_dict (dict): + { + "images": (S, 3, H, W) - Input images, + "world_points": (S, H, W, 3), + "world_points_conf": (S, H, W), + "depth": (S, H, W, 1), + "depth_conf": (S, H, W), + "extrinsic": (S, 3, 4), + "intrinsic": (S, 3, 3), + } + port (int): Port number for the viser server. + init_conf_threshold (float): Initial percentage of low-confidence points to filter out. + use_point_map (bool): Whether to visualize world_points or use depth-based points. + background_mode (bool): Whether to run the server in background thread. + mask_sky (bool): Whether to apply sky segmentation to filter out sky points. + image_folder (str): Path to the folder containing input images. + """ + print(f"Starting viser server on port {port}") + + server = viser.ViserServer(host="0.0.0.0", port=port) + server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") + + # Unpack prediction dict + images = pred_dict["images"] # (S, 3, H, W) + world_points_map = pred_dict["world_points"] # (S, H, W, 3) + conf_map = pred_dict["world_points_conf"] # (S, H, W) + + depth_map = pred_dict["depth"] # (S, H, W, 1) + depth_conf = pred_dict["depth_conf"] # (S, H, W) + + extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4) + intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3) + + # Compute world points from depth if not using the precomputed point map + if not use_point_map: + world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) + conf = depth_conf + else: + world_points = world_points_map + conf = conf_map + + # Apply sky segmentation if enabled + if mask_sky and image_folder is not None: + conf = apply_sky_segmentation(conf, image_folder) + + # Convert images from (S, 3, H, W) to (S, H, W, 3) + # Then flatten everything for the point cloud + colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) + S, H, W, _ = world_points.shape + + # Flatten + points = world_points.reshape(-1, 3) + colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8) + conf_flat = conf.reshape(-1) + + cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) # shape (S, 4, 4) typically + # For convenience, we store only (3,4) portion + cam_to_world = cam_to_world_mat[:, :3, :] + + # Compute scene center and recenter + scene_center = np.mean(points, axis=0) + points_centered = points - scene_center + cam_to_world[..., -1] -= scene_center + + # Store frame indices so we can filter by frame + frame_indices = np.repeat(np.arange(S), H * W) + + # Build the viser GUI + gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True) + + # Now the slider represents percentage of points to filter out + gui_points_conf = server.gui.add_slider( + "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold + ) + + gui_frame_selector = server.gui.add_dropdown( + "Show Points from Frames", options=["All"] + [str(i) for i in range(S)], initial_value="All" + ) + + # Create the main point cloud handle + # Compute the threshold value as the given percentile + init_threshold_val = np.percentile(conf_flat, init_conf_threshold) + init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1) + point_cloud = server.scene.add_point_cloud( + name="viser_pcd", + points=points_centered[init_conf_mask], + colors=colors_flat[init_conf_mask], + point_size=0.001, + point_shape="circle", + ) + + # We will store references to frames & frustums so we can toggle visibility + frames: List[viser.FrameHandle] = [] + frustums: List[viser.CameraFrustumHandle] = [] + + def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None: + """ + Add camera frames and frustums to the scene. + extrinsics: (S, 3, 4) + images_: (S, 3, H, W) + """ + # Clear any existing frames or frustums + for f in frames: + f.remove() + frames.clear() + for fr in frustums: + fr.remove() + frustums.clear() + + # Optionally attach a callback that sets the viewpoint to the chosen camera + def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None: + @frustum.on_click + def _(_) -> None: + for client in server.get_clients().values(): + client.camera.wxyz = frame.wxyz + client.camera.position = frame.position + + img_ids = range(S) + for img_id in tqdm(img_ids): + cam2world_3x4 = extrinsics[img_id] + T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4) + + # Add a small frame axis + frame_axis = server.scene.add_frame( + f"frame_{img_id}", + wxyz=T_world_camera.rotation().wxyz, + position=T_world_camera.translation(), + axes_length=0.05, + axes_radius=0.002, + origin_radius=0.002, + ) + frames.append(frame_axis) + + # Convert the image for the frustum + img = images_[img_id] # shape (3, H, W) + img = (img.transpose(1, 2, 0) * 255).astype(np.uint8) + h, w = img.shape[:2] + + # If you want correct FOV from intrinsics, do something like: + # fx = intrinsics_cam[img_id, 0, 0] + # fov = 2 * np.arctan2(h/2, fx) + # For demonstration, we pick a simple approximate FOV: + fy = 1.1 * h + fov = 2 * np.arctan2(h / 2, fy) + + # Add the frustum + frustum_cam = server.scene.add_camera_frustum( + f"frame_{img_id}/frustum", fov=fov, aspect=w / h, scale=0.05, image=img, line_width=1.0 + ) + frustums.append(frustum_cam) + attach_callback(frustum_cam, frame_axis) + + def update_point_cloud() -> None: + """Update the point cloud based on current GUI selections.""" + # Here we compute the threshold value based on the current percentage + current_percentage = gui_points_conf.value + threshold_val = np.percentile(conf_flat, current_percentage) + + print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%") + + conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5) + + if gui_frame_selector.value == "All": + frame_mask = np.ones_like(conf_mask, dtype=bool) + else: + selected_idx = int(gui_frame_selector.value) + frame_mask = frame_indices == selected_idx + + combined_mask = conf_mask & frame_mask + point_cloud.points = points_centered[combined_mask] + point_cloud.colors = colors_flat[combined_mask] + + @gui_points_conf.on_update + def _(_) -> None: + update_point_cloud() + + @gui_frame_selector.on_update + def _(_) -> None: + update_point_cloud() + + @gui_show_frames.on_update + def _(_) -> None: + """Toggle visibility of camera frames and frustums.""" + for f in frames: + f.visible = gui_show_frames.value + for fr in frustums: + fr.visible = gui_show_frames.value + + # Add the camera frames to the scene + visualize_frames(cam_to_world, images) + + print("Starting viser server...") + # If background_mode is True, spawn a daemon thread so the main thread can continue. + if background_mode: + + def server_loop(): + while True: + time.sleep(0.001) + + thread = threading.Thread(target=server_loop, daemon=True) + thread.start() + else: + while True: + time.sleep(0.01) + + return server + + +# Helper functions for sky segmentation + + +def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray: + """ + Apply sky segmentation to confidence scores. + + Args: + conf (np.ndarray): Confidence scores with shape (S, H, W) + image_folder (str): Path to the folder containing input images + + Returns: + np.ndarray: Updated confidence scores with sky regions masked out + """ + S, H, W = conf.shape + sky_masks_dir = image_folder.rstrip("/") + "_sky_masks" + os.makedirs(sky_masks_dir, exist_ok=True) + + # Download skyseg.onnx if it doesn't exist + if not os.path.exists("skyseg.onnx"): + print("Downloading skyseg.onnx...") + download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx") + + skyseg_session = onnxruntime.InferenceSession("skyseg.onnx") + image_files = sorted(glob.glob(os.path.join(image_folder, "*"))) + sky_mask_list = [] + + print("Generating sky masks...") + for i, image_path in enumerate(tqdm(image_files[:S])): # Limit to the number of images in the batch + image_name = os.path.basename(image_path) + mask_filepath = os.path.join(sky_masks_dir, image_name) + + if os.path.exists(mask_filepath): + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + sky_mask = segment_sky(image_path, skyseg_session, mask_filepath) + + # Resize mask to match HΓ—W if needed + if sky_mask.shape[0] != H or sky_mask.shape[1] != W: + sky_mask = cv2.resize(sky_mask, (W, H)) + + sky_mask_list.append(sky_mask) + + # Convert list to numpy array with shape SΓ—HΓ—W + sky_mask_array = np.array(sky_mask_list) + # Apply sky mask to confidence scores + sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32) + conf = conf * sky_mask_binary + + print("Sky segmentation applied successfully") + return conf + + +parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization") +parser.add_argument( + "--image_folder", type=str, default="examples/kitchen/images/", help="Path to folder containing images" +) +parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points") +parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode") +parser.add_argument("--port", type=int, default=8080, help="Port number for the viser server") +parser.add_argument( + "--conf_threshold", type=float, default=25.0, help="Initial percentage of low-confidence points to filter out" +) +parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points") + + +def main(): + """ + Main function for the VGGT demo with viser for 3D visualization. + + This function: + 1. Loads the VGGT model + 2. Processes input images from the specified folder + 3. Runs inference to generate 3D points and camera poses + 4. Optionally applies sky segmentation to filter out sky points + 5. Visualizes the results using viser + + Command-line arguments: + --image_folder: Path to folder containing input images + --use_point_map: Use point map instead of depth-based points + --background_mode: Run the viser server in background mode + --port: Port number for the viser server + --conf_threshold: Initial percentage of low-confidence points to filter out + --mask_sky: Apply sky segmentation to filter out sky points + """ + args = parser.parse_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + print("Initializing and loading VGGT model...") + # model = VGGT.from_pretrained("facebook/VGGT-1B") + + model = VGGT() + _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" + model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) + + model.eval() + model = model.to(device) + + # Use the provided image folder path + print(f"Loading images from {args.image_folder}...") + image_names = glob.glob(os.path.join(args.image_folder, "*")) + print(f"Found {len(image_names)} images") + + images = load_and_preprocess_images(image_names).to(device) + print(f"Preprocessed images shape: {images.shape}") + + print("Running inference...") + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + predictions = model(images) + + print("Converting pose encoding to extrinsic and intrinsic matrices...") + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) + predictions["extrinsic"] = extrinsic + predictions["intrinsic"] = intrinsic + + print("Processing model outputs...") + for key in predictions.keys(): + if isinstance(predictions[key], torch.Tensor): + predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension and convert to numpy + + if args.use_point_map: + print("Visualizing 3D points from point map") + else: + print("Visualizing 3D points by unprojecting depth map by cameras") + + if args.mask_sky: + print("Sky segmentation enabled - will filter out sky points") + + print("Starting viser visualization...") + + viser_server = viser_wrapper( + predictions, + port=args.port, + init_conf_threshold=args.conf_threshold, + use_point_map=args.use_point_map, + background_mode=args.background_mode, + mask_sky=args.mask_sky, + image_folder=args.image_folder, + ) + print("Visualization complete") + + +if __name__ == "__main__": + main() diff --git a/docs/package.md b/docs/package.md new file mode 100644 index 0000000000000000000000000000000000000000..356df89b613f9b48dd47d8b993bf792715237a6b --- /dev/null +++ b/docs/package.md @@ -0,0 +1,45 @@ +# Alternative Installation Methods + +This document explains how to install VGGT as a package using different package managers. + +## Prerequisites + +Before installing VGGT as a package, you need to install PyTorch and torchvision. We don't list these as dependencies to avoid CUDA version mismatches. Install them first, with an example as: + +```bash +# install pytorch 2.3.1 with cuda 12.1 +pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 +``` + +## Installation Options + +### Install with pip + +The simplest way to install VGGT is using pip: + +```bash +pip install -e . +``` + +### Install and run with pixi + +[Pixi](https://pixi.sh) is a package management tool for creating reproducible environments. + +1. First, [download and install pixi](https://pixi.sh/latest/get_started/) +2. Then run: + +```bash +pixi run -e python demo_gradio.py +``` + +### Install and run with uv + +[uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver. + +1. First, [install uv](https://docs.astral.sh/uv/getting-started/installation/) +2. Then run: + +```bash +uv run --extra demo demo_gradio.py +``` + diff --git a/examples/demo.py b/examples/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c62c8a4e4d2dabcf565d42dc42fff3cc940437b8 --- /dev/null +++ b/examples/demo.py @@ -0,0 +1,20 @@ +import torch +from vggt.models.vggt import VGGT +from vggt.utils.load_fn import load_and_preprocess_images + +device = "cuda" if torch.cuda.is_available() else "cpu" +# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) +dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + +# Initialize the model and load the pretrained weights. +# This will automatically download the model weights the first time it's run, which may take a while. +model = VGGT.from_pretrained("facebook/VGGT-1B").to(device) + +# Load and preprocess example images (replace with your own image paths) +image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"] +images = load_and_preprocess_images(image_names).to(device) + +with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + # Predict attributes including cameras, depth maps, and point maps. + predictions = model(images) \ No newline at end of file diff --git a/examples/kitchen/images/00.png b/examples/kitchen/images/00.png new file mode 100644 index 0000000000000000000000000000000000000000..8c6483769ab24d8a6f7fe37529d64989f9b78159 --- /dev/null +++ b/examples/kitchen/images/00.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54527a575988094058cdc1975b421c48e0f446726473d0ac21ea55ecb24e96a7 +size 691089 diff --git a/examples/kitchen/images/01.png b/examples/kitchen/images/01.png new file mode 100644 index 0000000000000000000000000000000000000000..7609d57c771a63ceeb31022d99a93c2a0e88770f --- /dev/null +++ b/examples/kitchen/images/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ad4c6d74c16661ed427f8100124aaf53e7fd0577b32c362f13559dfad7027a7 +size 726182 diff --git a/examples/kitchen/images/02.png b/examples/kitchen/images/02.png new file mode 100644 index 0000000000000000000000000000000000000000..2f72261237cfc1f452d5aadceef7a4f427996f04 --- /dev/null +++ b/examples/kitchen/images/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:596bd54d26f889fc80cedee81d95dda709fa134d86ac199b6509337e413246d5 +size 789490 diff --git a/examples/kitchen/images/03.png b/examples/kitchen/images/03.png new file mode 100644 index 0000000000000000000000000000000000000000..40b21dae14acbcd5f0366b330ee6a57275992b7f --- /dev/null +++ b/examples/kitchen/images/03.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78193756310d9abaf81fa28902cf0b284260a0a916b085a7c08a4723eead1dd6 +size 828488 diff --git a/examples/kitchen/images/04.png b/examples/kitchen/images/04.png new file mode 100644 index 0000000000000000000000000000000000000000..98b6c88ccf1513aad6ce4fa15f7fd7ccf4a3a005 --- /dev/null +++ b/examples/kitchen/images/04.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca551254002a318228e19e46982813f3e489828796e98547ff632043f3002f9d +size 723884 diff --git a/examples/kitchen/images/05.png b/examples/kitchen/images/05.png new file mode 100644 index 0000000000000000000000000000000000000000..0bda016502e4cf6cb3b8c137b1ff860ef7044bfe --- /dev/null +++ b/examples/kitchen/images/05.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8dcd116d782d32b404d7e4aa69f462abbd048a0d8727440ec37f18cc4548ee4 +size 759467 diff --git a/examples/kitchen/images/06.png b/examples/kitchen/images/06.png new file mode 100644 index 0000000000000000000000000000000000000000..29f4f83aecdc4212b6d1f7abe5b072a4a58f4866 --- /dev/null +++ b/examples/kitchen/images/06.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fcc2b871c6fef6f3a3e0f06a3ffc1f0eee3e40afa2461f7c7c665057decb3e6 +size 673638 diff --git a/examples/kitchen/images/07.png b/examples/kitchen/images/07.png new file mode 100644 index 0000000000000000000000000000000000000000..d79a6649230834d92bb13677196d522efd3cc110 --- /dev/null +++ b/examples/kitchen/images/07.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28d21898de0e6370790839a40f7f45d84fbb3e6ff5809f0a0e14bd01bdef730e +size 855991 diff --git a/examples/kitchen/images/08.png b/examples/kitchen/images/08.png new file mode 100644 index 0000000000000000000000000000000000000000..c48db949284ddc2e8ab85c7cc768f968230b12d6 --- /dev/null +++ b/examples/kitchen/images/08.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0137a2bb3eb3e691d8d8b1f8884a9c8f99748888b1db770091d7acdf35fe8efa +size 676557 diff --git a/examples/kitchen/images/09.png b/examples/kitchen/images/09.png new file mode 100644 index 0000000000000000000000000000000000000000..cfb8361c95613d9645b72b6fc5f6988de30ab036 --- /dev/null +++ b/examples/kitchen/images/09.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ab59c1ef85d8169b404463f01b7ae4d287da12677126b68a3dce407ca2b9077 +size 796675 diff --git a/examples/kitchen/images/10.png b/examples/kitchen/images/10.png new file mode 100644 index 0000000000000000000000000000000000000000..c9134a04e9eacaa237113aa22889ef2bbede6242 --- /dev/null +++ b/examples/kitchen/images/10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f180cbf110bc65b89ad616328ad7d076dc3901a18def4b1337a134cdf65233a0 +size 730142 diff --git a/examples/kitchen/images/11.png b/examples/kitchen/images/11.png new file mode 100644 index 0000000000000000000000000000000000000000..e2e791119fcebf2b44c371745eb5bc5e704386af --- /dev/null +++ b/examples/kitchen/images/11.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:781196eadae8d907928e877e073289c0998e2b9e513d4f7580e147d15d1ae571 +size 798727 diff --git a/examples/kitchen/images/12.png b/examples/kitchen/images/12.png new file mode 100644 index 0000000000000000000000000000000000000000..cfa388fe326f2be28a14bdeaa30106ea09916c79 --- /dev/null +++ b/examples/kitchen/images/12.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd59b24dc8962ba0fc7fbb37b53a6d76fec9730c74e7e3235a06902b250e7d44 +size 706754 diff --git a/examples/kitchen/images/13.png b/examples/kitchen/images/13.png new file mode 100644 index 0000000000000000000000000000000000000000..304bac545b1530c97e68c9ed618436ae78a1b543 --- /dev/null +++ b/examples/kitchen/images/13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4cd39f22c766477bad741ff37a1ee5f71aecde8bb6762d869b4c9dca1ceacfb +size 755076 diff --git a/examples/kitchen/images/14.png b/examples/kitchen/images/14.png new file mode 100644 index 0000000000000000000000000000000000000000..521a170e47be595407853208d2cd39f61dd3d819 --- /dev/null +++ b/examples/kitchen/images/14.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5df1f398efc144271e342d7b65447e022a100b93b3850a755fbc66aff5fca0f2 +size 642363 diff --git a/examples/kitchen/images/15.png b/examples/kitchen/images/15.png new file mode 100644 index 0000000000000000000000000000000000000000..21bfe4b7b4fe9c76a75fc7a76798b6daa6657c0b --- /dev/null +++ b/examples/kitchen/images/15.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:325262829ddb11d1c7df1a8f1fef79a297332dad51870ab0d40a73f1dd6869b1 +size 639105 diff --git a/examples/kitchen/images/16.png b/examples/kitchen/images/16.png new file mode 100644 index 0000000000000000000000000000000000000000..1cc0b695fa9537101d30f2cd13805e9227102138 --- /dev/null +++ b/examples/kitchen/images/16.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9779a78d72fc25f2118a270f060afeacbcef149a4f012119ff041effa8727cbf +size 754320 diff --git a/examples/kitchen/images/17.png b/examples/kitchen/images/17.png new file mode 100644 index 0000000000000000000000000000000000000000..07e87ec96b631ac25fd30c3b676801aa0b2a4b7e --- /dev/null +++ b/examples/kitchen/images/17.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2549f4f505ea021eebe0bf579b969b6c162d2dee18b0c8e9d7a3c043d200e45b +size 773937 diff --git a/examples/kitchen/images/18.png b/examples/kitchen/images/18.png new file mode 100644 index 0000000000000000000000000000000000000000..5f499c47a25767b0405740b52dc62cd1ecae2799 --- /dev/null +++ b/examples/kitchen/images/18.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1c21131c4732756d5774dd732af86c1d39dea96fd2d613afd570633b3a76ef6 +size 829179 diff --git a/examples/kitchen/images/19.png b/examples/kitchen/images/19.png new file mode 100644 index 0000000000000000000000000000000000000000..0b655085a6533b76d87c163a9a38a2a5aa94cc21 --- /dev/null +++ b/examples/kitchen/images/19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d17680e77c6cb326eb4604e29f9e532db34769ca20b938e944ab53e8bd3798e2 +size 678031 diff --git a/examples/kitchen/images/20.png b/examples/kitchen/images/20.png new file mode 100644 index 0000000000000000000000000000000000000000..cf31d355bcf197763a2f5cce2a2dcd4e698e31bc --- /dev/null +++ b/examples/kitchen/images/20.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e9c835a0e0c1bc162a8bff6b93677c58cb53afaadca260b0ca2a388565b4cc2 +size 718249 diff --git a/examples/kitchen/images/21.png b/examples/kitchen/images/21.png new file mode 100644 index 0000000000000000000000000000000000000000..3971e96e273e87cc40edafeee8f3c9f2f9d91ca4 --- /dev/null +++ b/examples/kitchen/images/21.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0747b2d1b44ef538a9aa40a067881ef9d3ed5cacbf954c926a2bdf5f29c114e6 +size 786649 diff --git a/examples/kitchen/images/22.png b/examples/kitchen/images/22.png new file mode 100644 index 0000000000000000000000000000000000000000..45696a29d0e2f8bf5759b016b109cbc22e094b0c --- /dev/null +++ b/examples/kitchen/images/22.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77a0014d7c7d5802ce23cda4e102759274fd8f4c150271a3b61cbb2fe33b69b6 +size 674666 diff --git a/examples/kitchen/images/23.png b/examples/kitchen/images/23.png new file mode 100644 index 0000000000000000000000000000000000000000..bcad95029f1aa2e57eee5b9107b54150c049cb32 --- /dev/null +++ b/examples/kitchen/images/23.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a9415e9b8f08ff298829ffac779bb1e8dedccb3bf36060d59a7da2a35c4f790 +size 651508 diff --git a/examples/kitchen/images/24.png b/examples/kitchen/images/24.png new file mode 100644 index 0000000000000000000000000000000000000000..fad4822b3ed1c066be6882c530ee665138b0a20d --- /dev/null +++ b/examples/kitchen/images/24.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5199003307466bf4706a0898f139bf3590946f255d08c6b11d5aa9eede54c83a +size 799878 diff --git a/examples/llff_fern/images/000.png b/examples/llff_fern/images/000.png new file mode 100644 index 0000000000000000000000000000000000000000..869b752dc917076a62e853fc6e7fdd2eae274556 --- /dev/null +++ b/examples/llff_fern/images/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47f447d31a84d53494045087cbb8a40b877a68a76f549af14f6bb6f490a5b05d +size 670941 diff --git a/examples/llff_fern/images/001.png b/examples/llff_fern/images/001.png new file mode 100644 index 0000000000000000000000000000000000000000..31f0f45ffe0877b354cda68cd0adaeaab946b11c --- /dev/null +++ b/examples/llff_fern/images/001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05402df1d7247e794768461571c188737dcae5fcb34400990f5751244a3e41c0 +size 665851 diff --git a/examples/llff_fern/images/002.png b/examples/llff_fern/images/002.png new file mode 100644 index 0000000000000000000000000000000000000000..0ed72fbb8310075b12bb65807282cbed9ef9cd40 --- /dev/null +++ b/examples/llff_fern/images/002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e17135aa9b506fac24a9529ee56c37ef5a52c55498998d3de64cf3e46210dccc +size 651522 diff --git a/examples/llff_fern/images/003.png b/examples/llff_fern/images/003.png new file mode 100644 index 0000000000000000000000000000000000000000..c657345c0d2097533323dc8e2c5530f9f1ce81bd --- /dev/null +++ b/examples/llff_fern/images/003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3285c7cc6b4b75703a68f510072c5eca81cff9b983044426cbe2ca27d4e526c5 +size 653296 diff --git a/examples/llff_fern/images/004.png b/examples/llff_fern/images/004.png new file mode 100644 index 0000000000000000000000000000000000000000..705baca1830cfe2680984590af905861c64e868b --- /dev/null +++ b/examples/llff_fern/images/004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9976282206f9aff0fc3eaa3daa182ba93e0c6734c69bdf31f40989641b4f8fea +size 608918 diff --git a/examples/llff_fern/images/005.png b/examples/llff_fern/images/005.png new file mode 100644 index 0000000000000000000000000000000000000000..891222726f7b2c2f0b452fe47f78edc3fbc143d0 --- /dev/null +++ b/examples/llff_fern/images/005.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7ce2bcabcd2b2972c505e2649eb0f5b9efb30adcd455d93f5014370d53f2653 +size 632688 diff --git a/examples/llff_fern/images/006.png b/examples/llff_fern/images/006.png new file mode 100644 index 0000000000000000000000000000000000000000..b701dc88506f9b206f7e3d7d84301d7f07b0947d --- /dev/null +++ b/examples/llff_fern/images/006.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57dfa176d28662655adda9c0c04d6949424ddb3c702f533ddb332543dd1dcbdb +size 633972 diff --git a/examples/llff_fern/images/007.png b/examples/llff_fern/images/007.png new file mode 100644 index 0000000000000000000000000000000000000000..afcd5629f89d30388d0b885677ef866e74cccc03 --- /dev/null +++ b/examples/llff_fern/images/007.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da42e3dd198bc0951591b2a5e41bb15fbff18c2aef194d17a6acbf128487749e +size 632488 diff --git a/examples/llff_fern/images/008.png b/examples/llff_fern/images/008.png new file mode 100644 index 0000000000000000000000000000000000000000..708ae93b1edb171d75781e2728700ffc4ca71422 --- /dev/null +++ b/examples/llff_fern/images/008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8b8ac9860697b9cb1bfe41b358b03aef7a97ecb2fa9af61bc6e11210d99e8be +size 632795 diff --git a/examples/llff_fern/images/009.png b/examples/llff_fern/images/009.png new file mode 100644 index 0000000000000000000000000000000000000000..ad9ec7c0f0030b6656d3da6bd928737135ee5c20 --- /dev/null +++ b/examples/llff_fern/images/009.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:beb825c8fee0b21801bca59ddddf65e560bd87fdc6823ec733cb8e6be05002c9 +size 639930 diff --git a/examples/llff_fern/images/010.png b/examples/llff_fern/images/010.png new file mode 100644 index 0000000000000000000000000000000000000000..a8f1cd2d451011202fa1048218a3b5aaf51b670d --- /dev/null +++ b/examples/llff_fern/images/010.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:390054ae2969ce5ed0e8ed9b9c0501c527028b565e18807e0acd5d62a4627dae +size 637464 diff --git a/examples/llff_fern/images/011.png b/examples/llff_fern/images/011.png new file mode 100644 index 0000000000000000000000000000000000000000..6911b0a97f5aab895729ecf08c7e3de45020686a --- /dev/null +++ b/examples/llff_fern/images/011.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3875f648038011549183d4c6a273be3d75e868e7f4971b474089121acbf8d52 +size 618042 diff --git a/examples/llff_fern/images/012.png b/examples/llff_fern/images/012.png new file mode 100644 index 0000000000000000000000000000000000000000..39646f4ec1f0fc1953ce10408a25ca0b63126797 --- /dev/null +++ b/examples/llff_fern/images/012.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e003aa530890edaf01a34701e0a83f21d90d41e36e80b91dc5aba9a055a72063 +size 647085 diff --git a/examples/llff_fern/images/013.png b/examples/llff_fern/images/013.png new file mode 100644 index 0000000000000000000000000000000000000000..27afc368687492eed150f85c7385a12e10955fc4 --- /dev/null +++ b/examples/llff_fern/images/013.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4c393c28985d237c9152b7e053dbd06257904f6415696a8d710c038d9c45885 +size 650231 diff --git a/examples/llff_fern/images/014.png b/examples/llff_fern/images/014.png new file mode 100644 index 0000000000000000000000000000000000000000..23a06c74337b0f5678fbbdbe825620165614f3c2 --- /dev/null +++ b/examples/llff_fern/images/014.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8e538e8b96199689a92a718d23890f4bd726ceb15e4e649298545c5b038cdbf +size 640415 diff --git a/examples/llff_fern/images/015.png b/examples/llff_fern/images/015.png new file mode 100644 index 0000000000000000000000000000000000000000..e25123ee7ef849dcbdaa416e1020ec474b99d6a3 --- /dev/null +++ b/examples/llff_fern/images/015.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c67170a6506eaca89ad231c4d247f0d231199f79602df3f98f9b822f29ffe8df +size 631546 diff --git a/examples/llff_fern/images/016.png b/examples/llff_fern/images/016.png new file mode 100644 index 0000000000000000000000000000000000000000..555ec71772302ac55c1ea4a89cd8fd9eda80baac --- /dev/null +++ b/examples/llff_fern/images/016.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b94154b7a7e3e77ad23aff47095a5b40d4c4ed265a42093173cffdc6ad46475c +size 638566 diff --git a/examples/llff_fern/images/017.png b/examples/llff_fern/images/017.png new file mode 100644 index 0000000000000000000000000000000000000000..c5582c0cf65fc98325144c56a70ef37cbdcabe09 --- /dev/null +++ b/examples/llff_fern/images/017.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:762e9e73f3e280c2f8cb81639c8c0716dfd6ebb44afaf4fe11a3a3b3fe3dc8dd +size 642341 diff --git a/examples/llff_fern/images/018.png b/examples/llff_fern/images/018.png new file mode 100644 index 0000000000000000000000000000000000000000..3dcaae747f3605fa5ef1bec9083650e7bac4c762 --- /dev/null +++ b/examples/llff_fern/images/018.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f52218c23e67e6f24ba259b2ce2d561f136833737ff7a08509007c460f22e7f0 +size 634383 diff --git a/examples/llff_fern/images/019.png b/examples/llff_fern/images/019.png new file mode 100644 index 0000000000000000000000000000000000000000..da96b95a55395db82b259f72e3900202c77876af --- /dev/null +++ b/examples/llff_fern/images/019.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bafc2ce3db1c4df4ab03d1ca89eecf8162527c6929d07191bc93e487b46e35cf +size 645132 diff --git a/examples/llff_flower/images/000.png b/examples/llff_flower/images/000.png new file mode 100644 index 0000000000000000000000000000000000000000..17190edb79c9cf8e73fc589fe55eba3a36932560 --- /dev/null +++ b/examples/llff_flower/images/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44fd722188241bb64f6011382d32940c1b4e61214938bf688f793e29adef580e +size 655368 diff --git a/examples/llff_flower/images/001.png b/examples/llff_flower/images/001.png new file mode 100644 index 0000000000000000000000000000000000000000..94351623cd68a8904a1053ca9c26041f2da35666 --- /dev/null +++ b/examples/llff_flower/images/001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1e1a0b36fd63be4f7e7fc423981261ef3382751b31a8bd0871a826295ce9309 +size 644365 diff --git a/examples/llff_flower/images/002.png b/examples/llff_flower/images/002.png new file mode 100644 index 0000000000000000000000000000000000000000..2e45b7d6c27bfb5792db050e4a6f396f448b305d --- /dev/null +++ b/examples/llff_flower/images/002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0818138df11e5ed71b6e56d9222e1f7c9d9925a9d223b86a52a0b90a64edcb75 +size 658041 diff --git a/examples/llff_flower/images/003.png b/examples/llff_flower/images/003.png new file mode 100644 index 0000000000000000000000000000000000000000..f45bcce166a8a7037d8579d3d563d970978e1ff5 --- /dev/null +++ b/examples/llff_flower/images/003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1de85e5a9204f20a8adcc4f94d8b18ca0d62eea01e0603f31dcb285c09e4c87f +size 648040 diff --git a/examples/llff_flower/images/004.png b/examples/llff_flower/images/004.png new file mode 100644 index 0000000000000000000000000000000000000000..647f06ef1e6acbbec2e33d5ef59dcc74672f0b77 --- /dev/null +++ b/examples/llff_flower/images/004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3446490c61fb0be2cd46f8c2f90e0ed68d70b1941502235cbc37813f5d66fd02 +size 649558 diff --git a/examples/llff_flower/images/005.png b/examples/llff_flower/images/005.png new file mode 100644 index 0000000000000000000000000000000000000000..8fe81893ad2d8aa6a3f5463072f09aaefa890deb --- /dev/null +++ b/examples/llff_flower/images/005.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba9c3147864c9278952260ced2ab548e1929c8deae0433057172d7fb8d9a3ac5 +size 651164 diff --git a/examples/llff_flower/images/006.png b/examples/llff_flower/images/006.png new file mode 100644 index 0000000000000000000000000000000000000000..771292cd705d0acfeadf51ee0c072c6b6e379b2a --- /dev/null +++ b/examples/llff_flower/images/006.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ccbe4569017ba3c9c41069e24a45da3b7235ddb345b77f716cb6c6caddcdd35 +size 651084 diff --git a/examples/llff_flower/images/007.png b/examples/llff_flower/images/007.png new file mode 100644 index 0000000000000000000000000000000000000000..ce0f91d1d860446ea0a6833e05cea3f1e5b5de74 --- /dev/null +++ b/examples/llff_flower/images/007.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21743148bfd01a3ec8043b34aea4b31d9da168c7afa2c2a534bab708a74fab52 +size 646377 diff --git a/examples/llff_flower/images/008.png b/examples/llff_flower/images/008.png new file mode 100644 index 0000000000000000000000000000000000000000..68eeebd4052a8b4b17874c701acc85a00cfb8575 --- /dev/null +++ b/examples/llff_flower/images/008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9a5a74135f5cca735cfe6b749e6c1146b4cb9720f0af2e8e51e8d71669e73d6 +size 654628 diff --git a/examples/llff_flower/images/009.png b/examples/llff_flower/images/009.png new file mode 100644 index 0000000000000000000000000000000000000000..9c5a1de99b3f338d2f5ec13d3a82e5d1769517ee --- /dev/null +++ b/examples/llff_flower/images/009.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d157a1c7e4c29e9ebf67ec2a277c8f11af642a363297c83b4ac368f37d5e9df +size 640699 diff --git a/examples/llff_flower/images/010.png b/examples/llff_flower/images/010.png new file mode 100644 index 0000000000000000000000000000000000000000..b84be38d85490c368e49a1371c6bd909efeb35c5 --- /dev/null +++ b/examples/llff_flower/images/010.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:896e73f03afb64531459f708f32101a8e0229f2a81589f45e45e850f80bf5b50 +size 639638 diff --git a/examples/llff_flower/images/011.png b/examples/llff_flower/images/011.png new file mode 100644 index 0000000000000000000000000000000000000000..4a0f7ad5ad2415c3a600f2a9be910fee818ecfe8 --- /dev/null +++ b/examples/llff_flower/images/011.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddc998c793e6a8f8cdfddb01095f67443c3e878b0103448e3f6284936b7d5f2e +size 641743 diff --git a/examples/llff_flower/images/012.png b/examples/llff_flower/images/012.png new file mode 100644 index 0000000000000000000000000000000000000000..3af561eff1c30443f07f49894517822e7afb6935 --- /dev/null +++ b/examples/llff_flower/images/012.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7f31473bbbb263ea3894a222ef325797e38613799ad733ed989841d4808104 +size 642073 diff --git a/examples/llff_flower/images/013.png b/examples/llff_flower/images/013.png new file mode 100644 index 0000000000000000000000000000000000000000..9a376d48357c7f25a159c46358daa79672952f09 --- /dev/null +++ b/examples/llff_flower/images/013.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc911ef48dfe67985adac61a34611b2061969719f23d8d8cd9248152fc31d6ac +size 648678 diff --git a/examples/llff_flower/images/014.png b/examples/llff_flower/images/014.png new file mode 100644 index 0000000000000000000000000000000000000000..acb8bffed882b08f979b76881174a67118921be1 --- /dev/null +++ b/examples/llff_flower/images/014.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2de7e031f20632e4400e096f480c8af05434e963aa662c547e4310cfd8be818f +size 648039 diff --git a/examples/llff_flower/images/015.png b/examples/llff_flower/images/015.png new file mode 100644 index 0000000000000000000000000000000000000000..3b23049410208005009f8e474c7b6ba4c77d57df --- /dev/null +++ b/examples/llff_flower/images/015.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:206aa0ef88508ca52f7805b250581bb50cfae06a3cd66b4f0c63dba37882f367 +size 643913 diff --git a/examples/llff_flower/images/016.png b/examples/llff_flower/images/016.png new file mode 100644 index 0000000000000000000000000000000000000000..bec732717b6b91b6f13e7ecf5639d8bf07bb8c25 --- /dev/null +++ b/examples/llff_flower/images/016.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b1c9903fdef04fb57597e3cbd640e7be8b0f6dfba96f627bd903010d28632b9 +size 653688 diff --git a/examples/llff_flower/images/017.png b/examples/llff_flower/images/017.png new file mode 100644 index 0000000000000000000000000000000000000000..1a13484c39125b46b9c6c0194468b1587cf24c35 --- /dev/null +++ b/examples/llff_flower/images/017.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a50808b930f18458fd663f10b8b95d0b28dafc592549038ab5932d58f695594 +size 643224 diff --git a/examples/llff_flower/images/018.png b/examples/llff_flower/images/018.png new file mode 100644 index 0000000000000000000000000000000000000000..404158c0be346456a380df23bab5163c6c1cb2d4 --- /dev/null +++ b/examples/llff_flower/images/018.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:467a41a9681c89375ec739e6fd7dfb1e23e690a01428a002e7dfcdebee7449ff +size 654040 diff --git a/examples/llff_flower/images/019.png b/examples/llff_flower/images/019.png new file mode 100644 index 0000000000000000000000000000000000000000..d0bdc70c2014b1735633939f5b9e8ca714d76e04 --- /dev/null +++ b/examples/llff_flower/images/019.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5056407b0f62af6ae5312a3eceebe7d28fd760eb0095cbc3e4f236fe43f83c6 +size 637466 diff --git a/examples/llff_flower/images/020.png b/examples/llff_flower/images/020.png new file mode 100644 index 0000000000000000000000000000000000000000..4cae726c13677f36e118d00a6e12ff1a8800eb96 --- /dev/null +++ b/examples/llff_flower/images/020.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11eea76784e61353b76437f74920022e8f96f05d0ed9fc6a391bb6f37e436dcc +size 646821 diff --git a/examples/llff_flower/images/021.png b/examples/llff_flower/images/021.png new file mode 100644 index 0000000000000000000000000000000000000000..3c045a57cbcc28d60025b3503de9eee037065522 --- /dev/null +++ b/examples/llff_flower/images/021.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad5ecac1a2d1c506db7efe18d1d8cfb1f9df77a512e206fafd3127a5403a4356 +size 646606 diff --git a/examples/llff_flower/images/022.png b/examples/llff_flower/images/022.png new file mode 100644 index 0000000000000000000000000000000000000000..70bb7be2593cd595c2db9e6443c8643c6aea5b18 --- /dev/null +++ b/examples/llff_flower/images/022.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04b0ab612d7280ded696d06f873d7bbd51bf9ade9f3bd1273643d08105976b0b +size 646872 diff --git a/examples/llff_flower/images/023.png b/examples/llff_flower/images/023.png new file mode 100644 index 0000000000000000000000000000000000000000..10f76f81f56a37cb12bcd9bdf0009ea02c85f86e --- /dev/null +++ b/examples/llff_flower/images/023.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b39f03bd3efb535d7cfef42ad700b6ac8138c2905387ef0e2cd88525ed23477 +size 643117 diff --git a/examples/llff_flower/images/024.png b/examples/llff_flower/images/024.png new file mode 100644 index 0000000000000000000000000000000000000000..bd447c0047dd8ebdf3c88c3e6c235e4c48a5564a --- /dev/null +++ b/examples/llff_flower/images/024.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:097b45ceaba1150e68ca7767797a8f96afb4a118288c4692ee69e363af9b05d3 +size 639450 diff --git a/examples/room/images/no_overlap_1.png b/examples/room/images/no_overlap_1.png new file mode 100644 index 0000000000000000000000000000000000000000..2c7310f5d8b088bc000faa35aba3d549f14d4367 --- /dev/null +++ b/examples/room/images/no_overlap_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bab2bf089c84f6b5ba13a2969c7e7769a864545596c462c5471dd47466b2874c +size 306500 diff --git a/examples/room/images/no_overlap_2.jpg b/examples/room/images/no_overlap_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9e6731afc699d9f96421ae60828789e9d274c734 --- /dev/null +++ b/examples/room/images/no_overlap_2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6f2cac1c3271918177eb134fc080d5a43e40f71bf2a50fda946614a4204d3de +size 275326 diff --git a/examples/room/images/no_overlap_3.jpg b/examples/room/images/no_overlap_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..766b14e84a84229637b6e2db4505a3bae9d1f047 --- /dev/null +++ b/examples/room/images/no_overlap_3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c0e52375887d44657a25a34578d79f744378b870863023fed6f86dcbd84eeb0 +size 249085 diff --git a/examples/room/images/no_overlap_4.jpg b/examples/room/images/no_overlap_4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac8e88a24a6ec0c8570a02a67b3d46a4622d88c9 --- /dev/null +++ b/examples/room/images/no_overlap_4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47c17353216ebf663ca519b2c6e5d515301995386b4aa602a9b2bd508c8bffe0 +size 230462 diff --git a/examples/room/images/no_overlap_5.jpg b/examples/room/images/no_overlap_5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a0e1a795b0794714d093235878a16122497e3d9 --- /dev/null +++ b/examples/room/images/no_overlap_5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69c2af2f336d84712b2879ea5a334da1dc1f01095eb71d9d30fa6a26e5ad66af +size 265973 diff --git a/examples/room/images/no_overlap_6.jpg b/examples/room/images/no_overlap_6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4d46981beb355e12377942421e2f8f8fb0792473 --- /dev/null +++ b/examples/room/images/no_overlap_6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01cf8937c2da430ef49e9c0cb23a8031a698eebf2dd1261a37a5c1ee28f5a7f5 +size 270884 diff --git a/examples/room/images/no_overlap_7.jpg b/examples/room/images/no_overlap_7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fa37dd2712a10dc959a44c7990f40f5042d438f3 --- /dev/null +++ b/examples/room/images/no_overlap_7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:189a30e8bd6445c972eb6a8c31581e9af6d0bbc03b0345fba5ca023e678f5492 +size 260800 diff --git a/examples/room/images/no_overlap_8.jpg b/examples/room/images/no_overlap_8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37dcb508239b2c679eeb324ee4f5d67bf17fb83f --- /dev/null +++ b/examples/room/images/no_overlap_8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db13926aab6bcc0f7c4903839ccb0dc554ab3276fdcf73e0f304b596f5c15221 +size 191454 diff --git a/examples/single_cartoon/images/model_was_never_trained_on_single_image_or_cartoon.jpg b/examples/single_cartoon/images/model_was_never_trained_on_single_image_or_cartoon.jpg new file mode 100644 index 0000000000000000000000000000000000000000..98debd9f3342a5e0b31f975a9b64e0a596163858 --- /dev/null +++ b/examples/single_cartoon/images/model_was_never_trained_on_single_image_or_cartoon.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc9f663c119c0f5f227928bd1d8a7af43ee26092e9f4aa2c63dc15d026424afd +size 320856 diff --git a/examples/single_oil_painting/images/model_was_never_trained_on_single_image_or_oil_painting.png b/examples/single_oil_painting/images/model_was_never_trained_on_single_image_or_oil_painting.png new file mode 100644 index 0000000000000000000000000000000000000000..7835766a9fbd13a923a221f6917011a5b4d21e96 --- /dev/null +++ b/examples/single_oil_painting/images/model_was_never_trained_on_single_image_or_oil_painting.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4532995f19fdea4abd074b95447f696f4a6633340f0c4b5d4685f1316606d2d4 +size 398531 diff --git a/examples/videos/Colosseum.mp4 b/examples/videos/Colosseum.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0befee277f1e55cdedce2a768a8f9091422927bf --- /dev/null +++ b/examples/videos/Colosseum.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0aec156f9840622aa8deca50f1f4a96a05dd3a448afa79f8c4c87d60a2b8432c +size 2093930 diff --git a/examples/videos/fern.mp4 b/examples/videos/fern.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7837dd7ec4282b4604ca9b4348177c5793c819de --- /dev/null +++ b/examples/videos/fern.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e067e66c628cfbb3037e7b65641158f0901ea10d2c3a01d08ee9aaa5b33db8b8 +size 3173234 diff --git a/examples/videos/great_wall.mp4 b/examples/videos/great_wall.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c8f7393b273ffbd0bd4a09ecd0f5c0422ebb54ff --- /dev/null +++ b/examples/videos/great_wall.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26b14bdb2b7c412fc88dc3dabcfbb9c755668d0ea9560ea2951fe3e9ddfcd05b +size 514177 diff --git a/examples/videos/kitchen.mp4 b/examples/videos/kitchen.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..034294e0d8cfbba5055891c86b9e6151d8ad70bc --- /dev/null +++ b/examples/videos/kitchen.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16044a31972ef20a902c75f2f139facc1831b8526b872e01d82f17c8a8cf0412 +size 6986086 diff --git a/examples/videos/pyramid.mp4 b/examples/videos/pyramid.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..18f4a10fe78e4117b2c504c1dfea57b31152c023 --- /dev/null +++ b/examples/videos/pyramid.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b1336a7625ae0909fa8b1cc10558ce8d9f856e56bca49461d0618fd28a7597e +size 1130332 diff --git a/examples/videos/room.mp4 b/examples/videos/room.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..48148ccee74747c98f8fcda2ffabf29ccf069595 --- /dev/null +++ b/examples/videos/room.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b06d3d0a34614cb3c9385dc583a598614ec87589c5f8e5954ea6d79844326e +size 273800 diff --git a/examples/videos/single_cartoon.mp4 b/examples/videos/single_cartoon.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..dde048c13eec187c62836f27421907640201164b --- /dev/null +++ b/examples/videos/single_cartoon.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5418685a991dc3dba680cc1e2708214f479fc658ddb712344edc4274d373548 +size 228582 diff --git a/examples/videos/single_oil_painting.mp4 b/examples/videos/single_oil_painting.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..63bd5ed270d3aa428602a1adf0acc4a7c2eecd59 --- /dev/null +++ b/examples/videos/single_oil_painting.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97c3ab94b7970cbe7960eb2bedf8a38343a909db035d4776aff5ba51cb3daa65 +size 794368 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..81d4f1de65b9218aaf9f2c8c6a3596c9cfd19d48 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,52 @@ +[project] +authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}] +dependencies = [ + "numpy<2", + "Pillow", + "huggingface_hub", + "einops", + "safetensors", + "opencv-python", +] +name = "vggt" +requires-python = ">= 3.10" +version = "0.0.1" + +[project.optional-dependencies] +demo = [ + "gradio==5.17.1", + "viser==0.2.23", + "tqdm", + "hydra-core", + "omegaconf", + "opencv-python", + "scipy", + "onnxruntime", + "requests", + "trimesh", + "matplotlib", +] + +# Using setuptools as the build backend +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +# setuptools configuration +[tool.setuptools.packages.find] +where = ["."] +include = ["vggt*"] + +# Pixi configuration +[tool.pixi.workspace] +channels = ["conda-forge"] +platforms = ["linux-64"] + +[tool.pixi.pypi-dependencies] +vggt = { path = ".", editable = true } + +[tool.pixi.environments] +default = { solve-group = "default" } +demo = { features = ["demo"], solve-group = "default" } + +[tool.pixi.tasks] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1cb61a2b7fd415bef4e77df43fc0ec5b53c63d7b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +torch==2.3.1 +torchvision==0.18.1 +numpy==1.26.1 +Pillow +huggingface_hub +einops +safetensors diff --git a/requirements_demo.txt b/requirements_demo.txt new file mode 100644 index 0000000000000000000000000000000000000000..680c444a2c4e3f23491bc94712d2dcfab91f04c3 --- /dev/null +++ b/requirements_demo.txt @@ -0,0 +1,16 @@ +gradio==5.17.1 +viser==0.2.23 +tqdm +hydra-core +omegaconf +opencv-python +scipy +onnxruntime +requests +trimesh +matplotlib +# feel free to skip the dependencies below if you do not need demo_colmap.py +pycolmap==3.10.0 +pyceres==2.3 +git+https://github.com/jytime/LightGlue.git#egg=lightglue + diff --git a/vggt/dependency/__init__.py b/vggt/dependency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4eab06de2c911398e0339782e327bffd4cc9f91c --- /dev/null +++ b/vggt/dependency/__init__.py @@ -0,0 +1,3 @@ +from .track_modules.track_refine import refine_track +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor diff --git a/vggt/dependency/distortion.py b/vggt/dependency/distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..b3510230265dbd088844076e9d5763a35f7d712b --- /dev/null +++ b/vggt/dependency/distortion.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from typing import Union + +ArrayLike = Union[np.ndarray, torch.Tensor] + + +def _is_numpy(x: ArrayLike) -> bool: + return isinstance(x, np.ndarray) + + +def _is_torch(x: ArrayLike) -> bool: + return isinstance(x, torch.Tensor) + + +def _ensure_torch(x: ArrayLike) -> torch.Tensor: + """Convert input to torch tensor if it's not already one.""" + if _is_numpy(x): + return torch.from_numpy(x) + elif _is_torch(x): + return x + else: + return torch.tensor(x) + + +def single_undistortion(params, tracks_normalized): + """ + Apply undistortion to the normalized tracks using the given distortion parameters once. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + u_undist, v_undist = apply_distortion(params, u, v) + return torch.stack([u_undist, v_undist], dim=-1) + + +def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6): + """ + Iteratively undistort the normalized tracks using the given distortion parameters. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + max_iterations (int): Maximum number of iterations for the undistortion process. + max_step_norm (float): Maximum step norm for convergence. + rel_step_size (float): Relative step size for numerical differentiation. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + B, N, _ = tracks_normalized.shape + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + original_u, original_v = u.clone(), v.clone() + + eps = torch.finfo(u.dtype).eps + for idx in range(max_iterations): + u_undist, v_undist = apply_distortion(params, u, v) + dx = original_u - u_undist + dy = original_v - v_undist + + step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) + step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) + + J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u) + J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v) + J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u) + J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v) + + J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2) + + delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) + + u += delta[..., 0] + v += delta[..., 1] + + if torch.max((delta**2).sum(dim=-1)) < max_step_norm: + break + + return torch.stack([u, v], dim=-1) + + +def apply_distortion(extra_params, u, v): + """ + Applies radial or OpenCV distortion to the given 2D points. + + Args: + extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks. + v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks. + + Returns: + points2D (torch.Tensor): Distorted 2D points of shape BxNx2. + """ + extra_params = _ensure_torch(extra_params) + u = _ensure_torch(u) + v = _ensure_torch(v) + + num_params = extra_params.shape[1] + + if num_params == 1: + # Simple radial distortion + k = extra_params[:, 0] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k[:, None] * r2 + du = u * radial + dv = v * radial + + elif num_params == 2: + # RadialCameraModel distortion + k1, k2 = extra_params[:, 0], extra_params[:, 1] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + dv = v * radial + + elif num_params == 4: + # OpenCVCameraModel distortion + k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3]) + u2 = u * u + v2 = v * v + uv = u * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2) + dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2) + else: + raise ValueError("Unsupported number of distortion parameters") + + u = u.clone() + du + v = v.clone() + dv + + return u, v + + +if __name__ == "__main__": + import random + import pycolmap + + max_diff = 0 + for i in range(1000): + # Define distortion parameters (assuming 1 parameter for simplicity) + B = random.randint(1, 500) + track_num = random.randint(100, 1000) + params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters + tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # Batch size 1, 5 points + + # Undistort the tracks + undistorted_tracks = iterative_undistortion(params, tracks_normalized) + + for b in range(B): + pycolmap_intri = np.array([1, 0, 0, params[b].item()]) + pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0) + + undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy()) + diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median() + max_diff = max(max_diff, diff) + print(f"diff: {diff}, max_diff: {max_diff}") + + import pdb + + pdb.set_trace() diff --git a/vggt/dependency/np_to_pycolmap.py b/vggt/dependency/np_to_pycolmap.py new file mode 100644 index 0000000000000000000000000000000000000000..def76dc080128dc53459f9263372a59e5a3f456a --- /dev/null +++ b/vggt/dependency/np_to_pycolmap.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import pycolmap +from .projection import project_3D_points_np + + +def batch_np_matrix_to_pycolmap( + points3d, + extrinsics, + intrinsics, + tracks, + image_size, + masks=None, + max_reproj_error=None, + max_points3D_val=3000, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", + extra_params=None, + min_inlier_per_frame=64, + points_rgb=None, +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Check https://github.com/colmap/pycolmap for more details about its format + + NOTE that colmap expects images/cameras/points3D to be 1-indexed + so there is a +1 offset between colmap index and batch index + + + NOTE: different from VGGSfM, this function: + 1. Use np instead of torch + 2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) + """ + # points3d: Px3 + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # tracks: NxPx2 + # masks: NxP + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N, P, _ = tracks.shape + assert len(extrinsics) == N + assert len(intrinsics) == N + assert len(points3d) == P + assert image_size.shape[0] == 2 + + if max_reproj_error is not None: + projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics) + projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) + projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 + reproj_mask = projected_diff < max_reproj_error + + if masks is not None and reproj_mask is not None: + masks = np.logical_and(masks, reproj_mask) + elif masks is not None: + masks = masks + else: + masks = reproj_mask + + assert masks is not None + + if masks.sum(1).min() < min_inlier_per_frame: + print(f"Not enough inliers per frame, skip BA.") + return None, None + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + inlier_num = masks.sum(0) + valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + # Only add 3D points that have sufficient 2D points + for vidx in valid_idx: + # Use RGB colors if provided, otherwise use zeros + rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) + + num_points3D = len(valid_idx) + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params) + + camera = pycolmap.Camera( + model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world + ) + + points2D_list = [] + + point2D_idx = 0 + + # NOTE point3D_id start by 1 + for point3D_id in range(1, num_points3D + 1): + original_track_idx = valid_idx[point3D_id - 1] + + if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): + if masks[fidx][original_track_idx]: + # It seems we don't need +0.5 for BA + point2D_xy = tracks[fidx][original_track_idx] + # Please note when adding the Point2D object + # It not only requires the 2D xy location, but also the id to 3D point + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: + print(f"frame {fidx + 1} is out of BA") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction, valid_mask + + +def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"): + """ + Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. + + Args: + reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. + device (str): Ignored in NumPy version (kept for API compatibility). + camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). + + Returns: + tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. + """ + + num_images = len(reconstruction.images) + max_points3D_id = max(reconstruction.point3D_ids()) + points3D = np.zeros((max_points3D_id, 3)) + + for point3D_id in reconstruction.points3D: + points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz + + extrinsics = [] + intrinsics = [] + + extra_params = [] if camera_type == "SIMPLE_RADIAL" else None + + for i in range(num_images): + # Extract and append extrinsics + pyimg = reconstruction.images[i + 1] + pycam = reconstruction.cameras[pyimg.camera_id] + matrix = pyimg.cam_from_world.matrix() + extrinsics.append(matrix) + + # Extract and append intrinsics + calibration_matrix = pycam.calibration_matrix() + intrinsics.append(calibration_matrix) + + if camera_type == "SIMPLE_RADIAL": + extra_params.append(pycam.params[-1]) + + # Convert lists to NumPy arrays instead of torch tensors + extrinsics = np.stack(extrinsics) + intrinsics = np.stack(intrinsics) + + if camera_type == "SIMPLE_RADIAL": + extra_params = np.stack(extra_params) + extra_params = extra_params[:, None] + + return points3D, extrinsics, intrinsics, extra_params + + +######################################################## + + +def batch_np_matrix_to_pycolmap_wo_track( + points3d, + points_xyf, + points_rgb, + extrinsics, + intrinsics, + image_size, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Different from batch_np_matrix_to_pycolmap, this function does not use tracks. + + It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. + + Do NOT use this for BA. + """ + # points3d: Px3 + # points_xyf: Px3, with x, y coordinates and frame indices + # points_rgb: Px3, rgb colors + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N = len(extrinsics) + P = len(points3d) + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + for vidx in range(P): + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) + + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) + + camera = pycolmap.Camera( + model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world + ) + + points2D_list = [] + + point2D_idx = 0 + + points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx + points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] + + for point3D_batch_idx in points_belong_to_fidx: + point3D_id = point3D_batch_idx + 1 + point2D_xyf = points_xyf[point3D_batch_idx] + point2D_xy = point2D_xyf[:2] + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: + print(f"frame {fidx + 1} does not have any points") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction + + +def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): + """ + Helper function to get camera parameters based on camera type. + + Args: + fidx: Frame index + intrinsics: Camera intrinsic parameters + camera_type: Type of camera model + extra_params: Additional parameters for certain camera types + + Returns: + pycolmap_intri: NumPy array of camera parameters + """ + if camera_type == "PINHOLE": + pycolmap_intri = np.array( + [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] + ) + elif camera_type == "SIMPLE_PINHOLE": + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]) + elif camera_type == "SIMPLE_RADIAL": + raise NotImplementedError("SIMPLE_RADIAL is not supported yet") + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]]) + else: + raise ValueError(f"Camera type {camera_type} is not supported yet") + + return pycolmap_intri diff --git a/vggt/dependency/projection.py b/vggt/dependency/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..a98082dc2f5b3c057b398a03ab13dba470f4a111 --- /dev/null +++ b/vggt/dependency/projection.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from .distortion import apply_distortion + + +def img_from_cam_np( + intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0 +) -> np.ndarray: + """ + Apply intrinsics (and optional radial distortion) to camera-space points. + + Args + ---- + intrinsics : (B,3,3) camera matrix K. + points_cam : (B,3,N) homogeneous camera coords (x, y, z)α΅€. + extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. + default : value used for np.nan replacement. + + Returns + ------- + points2D : (B,N,2) pixel coordinates. + """ + # 1. perspective divide ─────────────────────────────────────── + z = points_cam[:, 2:3, :] # (B,1,N) + points_cam_norm = points_cam / z # (B,3,N) + uv = points_cam_norm[:, :2, :] # (B,2,N) + + # 2. optional distortion ────────────────────────────────────── + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = np.stack([uu, vv], axis=1) # (B,2,N) + + # 3. homogeneous coords then K multiplication ───────────────── + ones = np.ones_like(uv[:, :1, :]) # (B,1,N) + points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) + + # batched mat-mul: K Β· [u v 1]α΅€ + points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) + points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) + + return points2D.transpose(0, 2, 1) # (B,N,2) + + +def project_3D_points_np( + points3D: np.ndarray, + extrinsics: np.ndarray, + intrinsics: np.ndarray | None = None, + extra_params: np.ndarray | None = None, + *, + default: float = 0.0, + only_points_cam: bool = False, +): + """ + NumPy clone of ``project_3D_points``. + + Parameters + ---------- + points3D : (N,3) world-space points. + extrinsics : (B,3,4) [R|t] matrix for each of B cameras. + intrinsics : (B,3,3) K matrix (optional if you only need cam-space). + extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. + default : value used to replace NaNs. + only_points_cam : if True, skip the projection and return points_cam with points2D as None. + + Returns + ------- + (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, + and points_cam is (B,3,N) camera-space coordinates. + """ + # ----- 0. prep sizes ----------------------------------------------------- + N = points3D.shape[0] # #points + B = extrinsics.shape[0] # #cameras + + # ----- 1. world β†’ homogeneous ------------------------------------------- + w_h = np.ones((N, 1), dtype=points3D.dtype) + points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) + + # broadcast to every camera (no actual copying with np.broadcast_to) ------ + points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) + + # ----- 2. apply extrinsics (camera frame) ------------------------------ + # X_cam = E Β· X_hom + # einsum: E_(b i j) Β· X_(b n j) β†’ (b n i) + points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) + points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) + + if only_points_cam: + return None, points_cam + + # ----- 3. intrinsics + distortion --------------------------------------- + if intrinsics is None: + raise ValueError("`intrinsics` must be provided unless only_points_cam=True") + + points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default) + + return points2D, points_cam + + +def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + points3D (torch.Tensor): 3D points of shape Px3. + extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. + default (float): Default value to replace NaNs. + only_points_cam (bool): If True, skip the projection and return points2D as None. + + Returns: + tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, + and points_cam is of shape Bx3xN. + """ + with torch.cuda.amp.autocast(dtype=torch.double): + N = points3D.shape[0] # Number of points + B = extrinsics.shape[0] # Batch size, i.e., number of cameras + points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4 + # Reshape for batch processing + points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) + + if only_points_cam: + return None, points_cam + + # Step 2: Apply intrinsic parameters and (optional) distortion + points2D = img_from_cam(intrinsics, points_cam, extra_params, default) + + return points2D, points_cam + + +def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. + + Returns: + points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. + """ + + # Normalize by the third coordinate (homogeneous division) + points_cam = points_cam / points_cam[:, 2:3, :] + # Extract uv + uv = points_cam[:, :2, :] + + # Apply distortion if extra_params are provided + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = torch.stack([uu, vv], dim=1) + + # Prepare points_cam for batch matrix multiplication + points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN + + # Extract x and y coordinates + points2D = points2D_homo[:, :2, :] # Bx2xN + + # Replace NaNs with default value + points2D = torch.nan_to_num(points2D, nan=default) + + return points2D.transpose(1, 2) # BxNx2 + + +if __name__ == "__main__": + # Set up example input + B, N = 24, 10240 + + for _ in range(100): + points3D = np.random.rand(N, 3).astype(np.float64) + extrinsics = np.random.rand(B, 3, 4).astype(np.float64) + intrinsics = np.random.rand(B, 3, 3).astype(np.float64) + + # Convert to torch tensors + points3D_torch = torch.tensor(points3D) + extrinsics_torch = torch.tensor(extrinsics) + intrinsics_torch = torch.tensor(intrinsics) + + # Run NumPy implementation + points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics) + + # Run torch implementation + points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch) + + # Convert torch output to numpy + points2D_torch_np = points2D_torch.detach().numpy() + points_cam_torch_np = points_cam_torch.detach().numpy() + + # Compute difference + diff = np.abs(points2D_np - points2D_torch_np) + print("Difference between NumPy and PyTorch implementations:") + print(diff) + + # Check max error + max_diff = np.max(diff) + print(f"Maximum difference: {max_diff}") + + if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): + print("Implementations match closely.") + else: + print("Significant differences detected.") + + if points_cam_np is not None: + points_cam_diff = np.abs(points_cam_np - points_cam_torch_np) + print("Difference between NumPy and PyTorch camera-space coordinates:") + print(points_cam_diff) + + # Check max error + max_cam_diff = np.max(points_cam_diff) + print(f"Maximum camera-space coordinate difference: {max_cam_diff}") + + if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): + print("Camera-space coordinates match closely.") + else: + print("Significant differences detected in camera-space coordinates.") diff --git a/vggt/dependency/track_modules/__init__.py b/vggt/dependency/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/dependency/track_modules/base_track_predictor.py b/vggt/dependency/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..8218c014e20baa646b612e368d8bdd1841658d65 --- /dev/null +++ b/vggt/dependency/track_modules/base_track_predictor.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=4, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + fine=False, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.fine = fine + + self.flows_emb_dim = latent_dim // 2 + self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 + + if self.fine: + # TODO this is the old dummy code, will remove this when we train next model + self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 + else: + self.transformer_dim += (4 - self.transformer_dim % 4) % 4 + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + if not self.fine: + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. + note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2 + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + # Construct the correlation block + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for itr in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + # Compute the correlation (check the implementation of CorrBlock) + + fcorr_fn.corr(track_feats) + fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim + + corrdim = fcorrs.shape[3] + + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + if transformer_input.shape[2] < self.transformer_dim: + # pad the features to match the dimension + pad_dim = self.transformer_dim - transformer_input.shape[2] + pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) + transformer_input = torch.cat([transformer_input, pad], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta = self.updateformer(x) + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + if not self.fine: + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + vis_e = torch.sigmoid(vis_e) + else: + vis_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat + else: + return coord_preds, vis_e diff --git a/vggt/dependency/track_modules/blocks.py b/vggt/dependency/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e0017d2c25338d0ce5d3f31e3802282259c8fa36 --- /dev/null +++ b/vggt/dependency/track_modules/blocks.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + +from .utils import bilinear_sampler + +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=4): + super(BasicEncoder, self).__init__() + + self.stride = stride + self.norm_fn = "instance" + self.in_planes = output_dim // 2 + + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(output_dim // 2, stride=1) + self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) + self.layer3 = self._make_layer(output_dim, stride=2) + self.layer4 = self._make_layer(output_dim, stride=2) + + self.conv2 = nn.Conv2d( + output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros" + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + + a = _bilinear_intepolate(a, self.stride, H, W) + b = _bilinear_intepolate(b, self.stride, H, W) + c = _bilinear_intepolate(c, self.stride, H, W) + d = _bilinear_intepolate(d, self.stride, H, W) + + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + return x + + +class ShallowEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"): + super(ShallowEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = output_dim + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(output_dim, stride=2) + + self.layer2 = self._make_layer(output_dim, stride=2) + self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + self.in_planes = dim + + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + return layer1 + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + tmp = self.layer1(x) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = self.layer2(tmp) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = None + x = self.conv2(x) + x + + x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True) + + return x + + +def _bilinear_intepolate(x, stride, H, W): + return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True) + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, input_tensor, mask=None): + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + flow = self.flow_head(tokens) + return flow + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.padding_mode = padding_mode + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.multiple_track_feats = multiple_track_feats + + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + *_, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode) + corrs = corrs.view(B, S, N, -1) + + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 + return out + + def corr(self, targets): + B, S, N, C = targets.shape + if self.multiple_track_feats: + targets_split = targets.split(C // self.num_levels, dim=-1) + B, S, N, C = targets_split[0].shape + + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + *_, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) + if self.multiple_track_feats: + fmap1 = targets_split[i] + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) diff --git a/vggt/dependency/track_modules/modules.py b/vggt/dependency/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e89b26edc7717f04a897977041f26e5c4f1c52b2 --- /dev/null +++ b/vggt/dependency/track_modules/modules.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/vggt/dependency/track_modules/track_refine.py b/vggt/dependency/track_modules/track_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..54a7ace1d49686304e5fbf28c33168667c28e181 --- /dev/null +++ b/vggt/dependency/track_modules/track_refine.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from PIL import Image +import os +from typing import Union, Tuple + + +def refine_track( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960 +): + """ + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. + """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). + # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + if chunk < 0: + # Extract image patches based on top left corners + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + else: + patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) + + patch_feat_list = [] + for p in torch.split(patches, chunk): + patch_feat_list += [fine_fnet(p)] + patch_feat = torch.cat(patch_feat_list, 0) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +def refine_track_v0( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6 +): + """ + COPIED FROM VGGSfM + + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. + """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). + # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # Extract image patches based on top left corners + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +################################## NOTE: NOT USED ################################## + + +def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out): + """ + Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, + given the query point features and reference frame feature maps + """ + + from kornia.utils.grid import create_meshgrid + from kornia.geometry.subpix import dsnt + + # query_point_feat initial shape: B x N x C_out, + # query_point_feat indicates the feat at the coorponsing query points + # Therefore we don't have S dimension here + query_point_feat = query_point_feat.reshape(B, N, C_out) + # reshape and expand to B x (S-1) x N x C_out + query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) + # and reshape to (B*(S-1)*N) x C_out + query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) + + # Radius and size for computing the score + ssize = sradius * 2 + 1 + + # Reshape, you know it, so many reshaping operations + patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) + + # Again, we unfold the patches to smaller patches + # so that we can then focus on smaller patches + # patch_feat_unfold shape: + # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize + # well a bit scary, but actually not + patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) + + # Do the same stuffs above, i.e., the same as extracting patches + fine_prediction_floor = fine_pred_track.floor().int() + fine_level_floor_topleft = fine_prediction_floor - sradius + + # Clamp to ensure the smaller patch is valid + fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) + fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) + + # Prepare the batch indices and xy locations + + batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN + batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N + y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices + x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices + + reference_frame_feat = patch_feat_unfold.reshape( + B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize + ) + + # Note again, according to pytorch convention + # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] + reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices] + reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) + # pick the frames other than the first one, so we have S-1 frames here + reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize) + + # Compute similarity + sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) + softmax_temp = 1.0 / C_out**0.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + # 2D heatmaps + heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize + + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape( + 1, -1, 2 + ) + + var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2 + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability + + score = std.reshape(B, S - 1, N) + # set score as 1 for the query frame + score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) + + return score + + +def extract_glimpse( + tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None +): + B, C, W, H = tensor.shape + + h, w = size + xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 + ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 + + vy, vx = torch.meshgrid(ys, xs) + grid = torch.stack([vx, vy], dim=-1) # h, w, 2 + grid = grid[None] + + B, N, _ = offsets.shape + + offsets = offsets.reshape((B * N), 1, 1, 2) + offsets_grid = offsets + grid + + # normalised grid to [-1, 1] + offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2]) + + # BxCxHxW -> Bx1xCxHxW + tensor = tensor[:, None] + + # Bx1xCxHxW -> BxNxCxHxW + tensor = tensor.expand(-1, N, -1, -1, -1) + + # BxNxCxHxW -> (B*N)xCxHxW + tensor = tensor.reshape((B * N), C, W, H) + + sampled = torch.nn.functional.grid_sample( + tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode + ) + + # NOTE: I am not sure it should be h, w or w, h here + # but okay for sqaures + sampled = sampled.reshape(B, N, C, h, w) + + return sampled diff --git a/vggt/dependency/track_modules/utils.py b/vggt/dependency/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d8954e87beb85e71c5fa4b5d7eb4f2b476680e6f --- /dev/null +++ b/vggt/dependency/track_modules/utils.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/PoseDiffusion +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union +from einops import rearrange, repeat + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device) + else: + coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) + + coords -= 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/vggt/dependency/track_predict.py b/vggt/dependency/track_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..c15c23fea612acb9383d7f03d7779b6d0f2dbf82 --- /dev/null +++ b/vggt/dependency/track_predict.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from .vggsfm_utils import * + + +def predict_tracks( + images, + conf=None, + points_3d=None, + masks=None, + max_query_pts=2048, + query_frame_num=5, + keypoint_extractor="aliked+sp", + max_points_num=163840, + fine_tracking=True, + complete_non_vis=True, +): + """ + Predict tracks for the given images and masks. + + TODO: support non-square images + TODO: support masks + + + This function predicts the tracks for the given images and masks using the specified query method + and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. + + Args: + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. + points_3d: Tensor containing 3D points. Default is None. + masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None. + max_query_pts: Maximum number of query points. Default is 2048. + query_frame_num: Number of query frames to use. Default is 5. + keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". + max_points_num: Maximum number of points to process at once. Default is 163840. + fine_tracking: Whether to use fine tracking. Default is True. + complete_non_vis: Whether to augment non-visible frames. Default is True. + + Returns: + pred_tracks: Numpy array containing the predicted tracks. + pred_vis_scores: Numpy array containing the visibility scores for the tracks. + pred_confs: Numpy array containing the confidence scores for the tracks. + pred_points_3d: Numpy array containing the 3D points for the tracks. + pred_colors: Numpy array containing the point colors for the tracks. (0, 255) + """ + + device = images.device + dtype = images.dtype + tracker = build_vggsfm_tracker().to(device, dtype) + + # Find query frames + query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device) + + # Add the first image to the front if not already present + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + # TODO: add the functionality to handle the masks + keypoint_extractors = initialize_feature_extractors( + max_query_pts, extractor_method=keypoint_extractor, device=device + ) + + pred_tracks = [] + pred_vis_scores = [] + pred_confs = [] + pred_points_3d = [] + pred_colors = [] + + fmaps_for_tracker = tracker.process_images_to_fmaps(images) + + if fine_tracking: + print("For faster inference, consider disabling fine_tracking") + + for query_index in query_frame_indexes: + print(f"Predicting tracks for query frame {query_index}") + pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + pred_confs.append(pred_conf) + pred_points_3d.append(pred_point_3d) + pred_colors.append(pred_color) + + if complete_non_vis: + pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames( + pred_tracks, + pred_vis_scores, + pred_confs, + pred_points_3d, + pred_colors, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + min_vis=500, + non_vis_thresh=0.1, + device=device, + ) + + pred_tracks = np.concatenate(pred_tracks, axis=1) + pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) + pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None + pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None + pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + # from vggt.utils.visual_track import visualize_tracks_on_images + # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors + + +def _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, +): + """ + Process a single query frame for track prediction. + + Args: + query_index: Index of the query frame + images: Tensor of shape [S, 3, H, W] containing the input images + conf: Confidence tensor + points_3d: 3D points tensor + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + device: Device to use for computation + + Returns: + pred_track: Predicted tracks + pred_vis: Visibility scores for the tracks + pred_conf: Confidence scores for the tracks + pred_point_3d: 3D points for the tracks + pred_color: Point colors for the tracks (0, 255) + """ + frame_num, _, height, width = images.shape + + query_image = images[query_index] + query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False) + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + + # Extract the color at the keypoint locations + query_points_long = query_points.squeeze(0).round().long() + pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]] + pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) + + # Query the confidence and points_3d at the keypoint locations + if (conf is not None) and (points_3d is not None): + assert height == width + assert conf.shape[-2] == conf.shape[-1] + assert conf.shape[:3] == points_3d.shape[:3] + scale = conf.shape[-1] / width + + query_points_scaled = (query_points.squeeze(0) * scale).round().long() + query_points_scaled = query_points_scaled.cpu().numpy() + + pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] + pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] + + # heuristic to remove low confidence points + # should I export this as an input parameter? + valid_mask = pred_conf > 1.2 + if valid_mask.sum() > 512: + query_points = query_points[:, valid_mask] # Make sure shape is compatible + pred_conf = pred_conf[valid_mask] + pred_point_3d = pred_point_3d[valid_mask] + pred_color = pred_color[valid_mask] + else: + pred_conf = None + pred_point_3d = None + + reorder_index = calculate_index_mappings(query_index, frame_num, device=device) + + images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0) + images_feed = images_feed[None] # add batch dimension + fmaps_feed = fmaps_feed[None] # add batch dimension + + all_points_num = images_feed.shape[1] * query_points.shape[1] + + # Don't need to be scared, this is just chunking to make GPU happy + if all_points_num > max_points_num: + num_splits = (all_points_num + max_points_num - 1) // max_points_num + query_points = torch.chunk(query_points, num_splits, dim=1) + else: + query_points = [query_points] + + pred_track, pred_vis, _ = predict_tracks_in_chunks( + tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking + ) + + pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1) + + pred_track = pred_track.squeeze(0).float().cpu().numpy() + pred_vis = pred_vis.squeeze(0).float().cpu().numpy() + + return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color + + +def _augment_non_visible_frames( + pred_tracks: list, # ← running list of np.ndarrays + pred_vis_scores: list, # ← running list of np.ndarrays + pred_confs: list, # ← running list of np.ndarrays for confidence scores + pred_points_3d: list, # ← running list of np.ndarrays for 3D points + pred_colors: list, # ← running list of np.ndarrays for colors + images: torch.Tensor, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num: int, + fine_tracking: bool, + *, + min_vis: int = 500, + non_vis_thresh: float = 0.1, + device: torch.device = None, +): + """ + Augment tracking for frames with insufficient visibility. + + Args: + pred_tracks: List of numpy arrays containing predicted tracks. + pred_vis_scores: List of numpy arrays containing visibility scores. + pred_confs: List of numpy arrays containing confidence scores. + pred_points_3d: List of numpy arrays containing 3D points. + pred_colors: List of numpy arrays containing point colors. + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing confidence scores + points_3d: Tensor containing 3D points + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + min_vis: Minimum visibility threshold + non_vis_thresh: Non-visibility threshold + device: Device to use for computation + + Returns: + Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. + """ + last_query = -1 + final_trial = False + cur_extractors = keypoint_extractors # may be replaced on the final trial + + while True: + # Visibility per frame + vis_array = np.concatenate(pred_vis_scores, axis=1) + + # Count frames with sufficient visibility using numpy + sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) + non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() + + if len(non_vis_frames) == 0: + break + + print("Processing non visible frames:", non_vis_frames) + + # Decide the frames & extractor for this round + if non_vis_frames[0] == last_query: + # Same frame failed twice - final "all-in" attempt + final_trial = True + cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device) + query_frame_list = non_vis_frames # blast them all at once + else: + query_frame_list = [non_vis_frames[0]] # Process one at a time + + last_query = non_vis_frames[0] + + # Run the tracker for every selected frame + for query_index in query_frame_list: + new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + cur_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + pred_tracks.append(new_track) + pred_vis_scores.append(new_vis) + pred_confs.append(new_conf) + pred_points_3d.append(new_point_3d) + pred_colors.append(new_color) + + if final_trial: + break # Stop after final attempt + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors diff --git a/vggt/dependency/vggsfm_tracker.py b/vggt/dependency/vggsfm_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..d79aeef000dcfec506dc4afb4e500d22a758122b --- /dev/null +++ b/vggt/dependency/vggsfm_tracker.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from hydra.utils import instantiate +from omegaconf import OmegaConf + +from .track_modules.track_refine import refine_track +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackerPredictor(nn.Module): + def __init__(self, **extra_args): + super(TrackerPredictor, self).__init__() + """ + Initializes the tracker predictor. + + Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, + check track_modules/base_track_predictor.py + + Both coarse_fnet and fine_fnet are constructed as a 2D CNN network + check track_modules/blocks.py for BasicEncoder and ShallowEncoder + """ + # Define coarse predictor configuration + coarse_stride = 4 + self.coarse_down_ratio = 2 + + # Create networks directly instead of using instantiate + self.coarse_fnet = BasicEncoder(stride=coarse_stride) + self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) + + # Create fine predictor with stride = 1 + self.fine_fnet = ShallowEncoder(stride=1) + self.fine_predictor = BaseTrackerPredictor( + stride=1, + depth=4, + corr_levels=3, + corr_radius=3, + latent_dim=32, + hidden_size=256, + fine=True, + use_spaceatt=False, + ) + + def forward( + self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960 + ): + """ + Args: + images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. + query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. + fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. + coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. + inference (bool, optional): Whether to perform inference. Defaults to True. + fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. + + Returns: + tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. + """ + + if fmaps is None: + batch_num, frame_num, image_dim, height, width = images.shape + reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) + fmaps = self.process_images_to_fmaps(reshaped_image) + fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]) + + if inference: + torch.cuda.empty_cache() + + # Coarse prediction + coarse_pred_track_lists, pred_vis = self.coarse_predictor( + query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio + ) + coarse_pred_track = coarse_pred_track_lists[-1] + + if inference: + torch.cuda.empty_cache() + + if fine_tracking: + # Refine the coarse prediction + fine_pred_track, pred_score = refine_track( + images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk + ) + + if inference: + torch.cuda.empty_cache() + else: + fine_pred_track = coarse_pred_track + pred_score = torch.ones_like(pred_vis) + + return fine_pred_track, coarse_pred_track, pred_vis, pred_score + + def process_images_to_fmaps(self, images): + """ + This function processes images for inference. + + Args: + images (torch.Tensor): The images to be processed with shape S x 3 x H x W. + + Returns: + torch.Tensor: The processed feature maps. + """ + if self.coarse_down_ratio > 1: + # whether or not scale down the input images to save memory + fmaps = self.coarse_fnet( + F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True) + ) + else: + fmaps = self.coarse_fnet(images) + + return fmaps diff --git a/vggt/dependency/vggsfm_utils.py b/vggt/dependency/vggsfm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7d9ba6a28da07b7f030a17730f4826feaa828e --- /dev/null +++ b/vggt/dependency/vggsfm_utils.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pycolmap +import torch +import torch.nn.functional as F +from lightglue import ALIKED, SIFT, SuperPoint + +from .vggsfm_tracker import TrackerPredictor + +# Suppress verbose logging from dependencies +logging.getLogger("dinov2").setLevel(logging.WARNING) +warnings.filterwarnings("ignore", message="xFormers is available") +warnings.filterwarnings("ignore", message="dinov2") + +# Constants +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +def build_vggsfm_tracker(model_path=None): + """ + Build and initialize the VGGSfM tracker. + + Args: + model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. + + Returns: + Initialized tracker model in eval mode. + """ + tracker = TrackerPredictor() + + if model_path is None: + default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" + tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) + else: + tracker.load_state_dict(torch.load(model_path)) + + tracker.eval() + return tracker + + +def generate_rank_by_dino( + images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False +): + """ + Generate a ranking of frames using DINO ViT features. + + Args: + images: Tensor of shape (S, 3, H, W) with values in range [0, 1] + query_frame_num: Number of frames to select + image_size: Size to resize images to before processing + model_name: Name of the DINO model to use + device: Device to run the model on + spatial_similarity: Whether to use spatial token similarity or CLS token similarity + + Returns: + List of frame indices ranked by their representativeness + """ + # Resize images to the target size + images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False) + + # Load DINO model + dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dino_v2_model.eval() + dino_v2_model = dino_v2_model.to(device) + + # Normalize images using ResNet normalization + resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) + resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) + images_resnet_norm = (images - resnet_mean) / resnet_std + + with torch.no_grad(): + frame_feat = dino_v2_model(images_resnet_norm, is_training=True) + + # Process features based on similarity type + if spatial_similarity: + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similarity matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + similarity_matrix = similarity_matrix.mean(dim=0) + else: + frame_feat = frame_feat["x_norm_clstoken"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + + distance_matrix = 100 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(-100) + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling starting from the most common frame + fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) + + # Clean up all tensors and models to free memory + del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix + del dino_v2_model + torch.cuda.empty_cache() + + return fps_idx + + +def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): + """ + Farthest point sampling algorithm to select diverse frames. + + Args: + distance_matrix: Matrix of distances between frames + num_samples: Number of frames to select + most_common_frame_index: Index of the first frame to select + + Returns: + List of selected frame indices + """ + distance_matrix = distance_matrix.clamp(min=0) + N = distance_matrix.size(0) + + # Initialize with the most common frame + selected_indices = [most_common_frame_index] + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + # Find the farthest point from the current set of selected points + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + + check_distances = distance_matrix[farthest_point] + # Mark already selected points to avoid selecting them again + check_distances[selected_indices] = 0 + + # Break if all points have been selected + if len(selected_indices) == N: + break + + return selected_indices + + +def calculate_index_mappings(query_index, S, device=None): + """ + Construct an order that switches [query_index] and [0] + so that the content of query_index would be placed at [0]. + + Args: + query_index: Index to swap with 0 + S: Total number of elements + device: Device to place the tensor on + + Returns: + Tensor of indices with the swapped order + """ + new_order = torch.arange(S) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors, order, dim=1): + """ + Reorder tensors along a specific dimension according to the given order. + + Args: + tensors: List of tensors to reorder + order: Tensor of indices specifying the new order + dim: Dimension along which to reorder + + Returns: + List of reordered tensors + """ + return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors] + + +def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"): + """ + Initialize feature extractors that can be reused based on a method string. + + Args: + max_query_num: Maximum number of keypoints to extract + det_thres: Detection threshold for keypoint extraction + extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") + device: Device to run extraction on + + Returns: + Dictionary of initialized extractors + """ + extractors = {} + methods = extractor_method.lower().split("+") + + for method in methods: + method = method.strip() + if method == "aliked": + aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["aliked"] = aliked_extractor.to(device).eval() + elif method == "sp": + sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["sp"] = sp_extractor.to(device).eval() + elif method == "sift": + sift_extractor = SIFT(max_num_keypoints=max_query_num) + extractors["sift"] = sift_extractor.to(device).eval() + else: + print(f"Warning: Unknown feature extractor '{method}', ignoring.") + + if not extractors: + print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") + aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["aliked"] = aliked_extractor.to(device).eval() + + return extractors + + +def extract_keypoints(query_image, extractors, round_keypoints=True): + """ + Extract keypoints using pre-initialized feature extractors. + + Args: + query_image: Input image tensor (3xHxW, range [0, 1]) + extractors: Dictionary of initialized extractors + + Returns: + Tensor of keypoint coordinates (1xNx2) + """ + query_points = None + + with torch.no_grad(): + for extractor_name, extractor in extractors.items(): + query_points_data = extractor.extract(query_image, invalid_mask=None) + extractor_points = query_points_data["keypoints"] + if round_keypoints: + extractor_points = extractor_points.round() + + if query_points is not None: + query_points = torch.cat([query_points, extractor_points], dim=1) + else: + query_points = extractor_points + + return query_points + + +def predict_tracks_in_chunks( + track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960 +): + """ + Process a list of query points to avoid memory issues. + + Args: + track_predictor (object): The track predictor object used for predicting tracks. + images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. + query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. + fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. + fine_tracking (bool): Whether to perform fine tracking. + num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. + + Returns: + tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. + """ + # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility + if not isinstance(query_points_list, (list, tuple)): + query_points = query_points_list + if num_splits is None: + num_splits = 1 + query_points_list = torch.chunk(query_points, num_splits, dim=1) + + # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) + if isinstance(query_points_list, tuple): + query_points_list = list(query_points_list) + + fine_pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for split_points in query_points_list: + # Feed into track predictor for each split + fine_pred_track, _, pred_vis, pred_score = track_predictor( + images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk + ) + fine_pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + # Concatenate the results from all splits + fine_pred_track = torch.cat(fine_pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + + if pred_score is not None: + pred_score = torch.cat(pred_score_list, dim=2) + else: + pred_score = None + + return fine_pred_track, pred_vis, pred_score diff --git a/vggt/heads/camera_head.py b/vggt/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..215adf39de23abd4975479d332250fcc3e2b54b9 --- /dev/null +++ b/vggt/heads/camera_head.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vggt.layers import Mlp +from vggt.layers.block import Block +from vggt.heads.head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape # S is expected to be 1. + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/vggt/heads/dpt_head.py b/vggt/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4c487a38ade5e419048bdd909ecc969ab11be66f --- /dev/null +++ b/vggt/heads/dpt_head.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch(out_channels, features, expand=False) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. + + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.view(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. + out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/vggt/heads/head_act.py b/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489 --- /dev/null +++ b/vggt/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/vggt/heads/track_head.py b/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f1d9bd83cca1f74f97a644a02b984904f84706 --- /dev/null +++ b/vggt/heads/track_head.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. + iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. + """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor, optional): Initial query points to track. + If None, points are initialized by the tracker. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (torch.Tensor): Predicted coordinates for tracked points. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). + """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters) + + return coord_preds, vis_scores, conf_scores diff --git a/vggt/heads/track_modules/__init__.py b/vggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/vggt/heads/track_modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/vggt/heads/track_modules/base_track_predictor.py b/vggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a --- /dev/null +++ b/vggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed +from .modules import Mlp + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + and https://github.com/facebookresearch/vggsfm + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. + note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/vggt/heads/track_modules/blocks.py b/vggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..15c161c89ef99742b0f2c6f397c9121fe9301e08 --- /dev/null +++ b/vggt/heads/track_modules/blocks.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + """ + Build a pyramid of feature maps from the input. + + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. Ξ”x, Ξ”y) + self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) β€” features for the current targets. + coords: Tensor (B, S, N, 2) β€” coordinates at full resolution. + + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. + delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) + + # Sample from the correlation volume using bilinear interpolation. + # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. + out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/vggt/heads/track_modules/modules.py b/vggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..12de4f1ad76364d4665e53ac80e1037fadf98d08 --- /dev/null +++ b/vggt/heads/track_modules/modules.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/vggt/heads/track_modules/utils.py b/vggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1fffeaedd33c7f1c2ef54220e24a2a0e5a57b2 --- /dev/null +++ b/vggt/heads/track_modules/utils.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/vggsfm +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype + ) + else: + scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/vggt/heads/utils.py b/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f --- /dev/null +++ b/vggt/heads/utils.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/vggt/layers/__init__.py b/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/vggt/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/vggt/layers/attention.py b/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8e823b4b7a93cca75e4cbab1cdfbbc3121a316fa --- /dev/null +++ b/vggt/layers/attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: + assert pos is None + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/vggt/layers/block.py b/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5847352a1f8f5d63da28c99e94270e50ccf3aa --- /dev/null +++ b/vggt/layers/block.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + x = drop_add_residual_stochastic_depth( + x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None), + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/vggt/layers/drop_path.py b/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/vggt/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/vggt/layers/layer_scale.py b/vggt/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddfc51c3d87370d50175f5b4e649dac1c614ff9 --- /dev/null +++ b/vggt/layers/layer_scale.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/vggt/layers/mlp.py b/vggt/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/vggt/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/vggt/layers/patch_embed.py b/vggt/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..bc19605e4d6e88d06355ae3b1afddc76f595aafe --- /dev/null +++ b/vggt/layers/patch_embed.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/vggt/layers/rope.py b/vggt/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df --- /dev/null +++ b/vggt/layers/rope.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. + """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/vggt/layers/swiglu_ffn.py b/vggt/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd991e1deb87141ccd282098d4b9d38fed6ef25 --- /dev/null +++ b/vggt/layers/swiglu_ffn.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) diff --git a/vggt/layers/vision_transformer.py b/vggt/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..deda8fde42b1b5b3340132c9c75338c65c9bea3f --- /dev/null +++ b/vggt/layers/vision_transformer.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.use_reentrant = False # hardcoded to False + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/vggt/models/aggregator.py b/vggt/models/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6b25d6df44a0dbf71b214f5084b2a21fcd087e --- /dev/null +++ b/vggt/models/aggregator.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from typing import Optional, Tuple, Union, List, Dict, Any + +from vggt.layers import PatchEmbed +from vggt.layers.block import Block +from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + Remember to set model.train() to enable gradient checkpointing to reduce memory usage. + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. + qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. + """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) + + # Initialize rotary position embedding if frequency > 0 + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): + self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False) + + self.use_reentrant = False # hardcoded to False + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. + """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten(self.camera_token, B, S) + register_token = slice_expand_and_flatten(self.register_token, B, S) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). + """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if self.training: + tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant) + else: + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). + """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if self.training: + tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant) + else: + tokens = self.global_blocks[global_idx](tokens, pos=pos) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/vggt/models/vggt.py b/vggt/models/vggt.py new file mode 100644 index 0000000000000000000000000000000000000000..9681bcadb3a15bcd157e230a731db0e16d8547f1 --- /dev/null +++ b/vggt/models/vggt.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub + +from vggt.models.aggregator import Aggregator +from vggt.heads.camera_head import CameraHead +from vggt.heads.dpt_head import DPTHead +from vggt.heads.track_head import TrackHead + + +class VGGT(nn.Module, PyTorchModelHubMixin): + def __init__(self, img_size=518, patch_size=14, embed_dim=1024): + super().__init__() + + self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + self.camera_head = CameraHead(dim_in=2 * embed_dim) + self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") + self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") + self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) + + def forward(self, images: torch.Tensor, query_points: torch.Tensor = None): + """ + Forward pass of the VGGT model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. + Shape: [N, 2] or [B, N, 2], where N is the number of query points. + Default: None + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + + If query_points is provided, also includes: + - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates + - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] + - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] + """ + + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + + aggregated_tokens_list, patch_start_idx = self.aggregator(images) + + predictions = {} + + with torch.cuda.amp.autocast(enabled=False): + if self.camera_head is not None: + pose_enc_list = self.camera_head(aggregated_tokens_list) + predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.track_head is not None and query_points is not None: + track_list, vis, conf = self.track_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points + ) + predictions["track"] = track_list[-1] # track of the last iteration + predictions["vis"] = vis + predictions["conf"] = conf + + predictions["images"] = images + + return predictions diff --git a/vggt/utils/geometry.py b/vggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9989ed7c94878a07d9fed3399847ca2acd8c4c --- /dev/null +++ b/vggt/utils/geometry.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, extrinsic: np.ndarray, intrinsic: np.ndarray, eps=1e-8 +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). + """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. + + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix diff --git a/vggt/utils/helper.py b/vggt/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7b019189c85ff86645a4cf3756632aa8d4500649 --- /dev/null +++ b/vggt/utils/helper.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + + +def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: + """ + If mask has more than max_trues True values, + randomly keep only max_trues of them and set the rest to False. + """ + # 1D positions of all True entries + true_indices = np.flatnonzero(mask) # shape = (N_true,) + + # if already within budget, return as-is + if true_indices.size <= max_trues: + return mask + + # randomly pick which True positions to keep + sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) + + # build new flat mask: True only at sampled positions + limited_flat_mask = np.zeros(mask.size, dtype=bool) + limited_flat_mask[sampled_indices] = True + + # restore original shape + return limited_flat_mask.reshape(mask.shape) + + +def create_pixel_coordinate_grid(num_frames, height, width): + """ + Creates a grid of pixel coordinates and frame indices for all frames. + Returns: + tuple: A tuple containing: + - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) + with x, y coordinates and frame indices + - y_coords (numpy.ndarray): Array of y coordinates for all frames + - x_coords (numpy.ndarray): Array of x coordinates for all frames + - f_coords (numpy.ndarray): Array of frame indices for all frames + """ + # Create coordinate grids for a single frame + y_grid, x_grid = np.indices((height, width), dtype=np.float32) + x_grid = x_grid[np.newaxis, :, :] + y_grid = y_grid[np.newaxis, :, :] + + # Broadcast to all frames + x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) + y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) + + # Create frame indices and broadcast + f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] + f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) + + # Stack coordinates and frame indices + points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) + + return points_xyf diff --git a/vggt/utils/load_fn.py b/vggt/utils/load_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..4d223aabdc43ac644c1b8ca376e8fec59decd084 --- /dev/null +++ b/vggt/utils/load_fn.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from torchvision import transforms as TF +import numpy as np + + +def load_and_preprocess_images_square(image_path_list, target_size=1024): + """ + Load and preprocess images by center padding to square and resizing to target size. + Also returns the position information of original pixels after transformation. + + Args: + image_path_list (list): List of paths to image files + target_size (int, optional): Target size for both width and height. Defaults to 518. + + Returns: + tuple: ( + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), + torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image + ) + + Raises: + ValueError: If the input list is empty + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + images = [] + original_coords = [] # Renamed from position_info to be more descriptive + to_tensor = TF.ToTensor() + + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background + if img.mode == "RGBA": + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + img = Image.alpha_composite(background, img) + + # Convert to RGB + img = img.convert("RGB") + + # Get original dimensions + width, height = img.size + + # Make the image square by padding the shorter dimension + max_dim = max(width, height) + + # Calculate padding + left = (max_dim - width) // 2 + top = (max_dim - height) // 2 + + # Calculate scale factor for resizing + scale = target_size / max_dim + + # Calculate final coordinates of original image in target space + x1 = left * scale + y1 = top * scale + x2 = (left + width) * scale + y2 = (top + height) * scale + + # Store original image coordinates and scale + original_coords.append(np.array([x1, y1, x2, y2, width, height])) + + # Create a new black square image and paste original + square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) + square_img.paste(img, (left, top)) + + # Resize to target size + square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC) + + # Convert to tensor + img_tensor = to_tensor(square_img) + images.append(img_tensor) + + # Stack all images + images = torch.stack(images) + original_coords = torch.from_numpy(np.array(original_coords)).float() + + # Add additional dimension if single image to ensure correct shape + if len(image_path_list) == 1: + if images.dim() == 3: + images = images.unsqueeze(0) + original_coords = original_coords.unsqueeze(0) + + return images, original_coords + + +def load_and_preprocess_images(image_path_list, mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. + + Args: + image_path_list (list): List of paths to image files + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. + + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background: + if img.mode == "RGBA": + # Create white background + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + # Alpha composite onto the white background + img = Image.alpha_composite(background, img) + + # Now convert to "RGB" (this step assigns white for transparent areas) + img = img.convert("RGB") + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(image_path_list) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + return images diff --git a/vggt/utils/pose_enc.py b/vggt/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..9d3b964330af0e62f4d36d332317ae00cb99b3a9 --- /dev/null +++ b/vggt/utils/pose_enc.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from .rotation import quat_to_mat, mat_to_quat + + +def extri_intri_to_pose_encoding( + extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512) +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/vggt/utils/rotation.py b/vggt/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..f972afd8414c82fa1e9ed231725fd3f9f6ebde77 --- /dev/null +++ b/vggt/utils/rotation.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/vggt/utils/visual_track.py b/vggt/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154 --- /dev/null +++ b/vggt/utils/visual_track.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import numpy as np +import os + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. + + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. + + Returns: + None (saves images in out_dir). + """ + + if len(tracks.shape) == 4: + tracks = tracks.squeeze(0) + images = images.squeeze(0) + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.squeeze(0) + + import matplotlib + + matplotlib.use("Agg") # for non-interactive (optional) + + os.makedirs(out_dir, exist_ok=True) + + S = images.shape[0] + _, N, _ = tracks.shape # (S, N, 2) + + # Move to CPU + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + # Infer H, W from images shape + if image_format == "CHW": + # e.g. images[s].shape = (3, H, W) + H, W = images.shape[2], images.shape[3] + else: + # e.g. images[s].shape = (H, W, 3) + H, W = images.shape[1], images.shape[2] + + # Pre-compute the color for each track i based on first visible position + track_colors_rgb = get_track_colors_by_position( + tracks, # shape (S, N, 2) + vis_mask_b=track_vis_mask if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name, + ) + + # We'll accumulate each frame's drawn image in a list + frame_images = [] + + for s in range(S): + # shape => either (3, H, W) or (H, W, 3) + img = images[s] + + # Convert to (H, W, 3) + if image_format == "CHW": + img = img.permute(1, 2, 0) # (H, W, 3) + # else "HWC", do nothing + + img = img.numpy().astype(np.float32) + + # Scale to [0,255] if needed + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + # else no normalization + + # Convert to uint8 + img = img.astype(np.uint8) + + # For drawing in OpenCV, convert to BGR + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + # Draw each visible track + cur_tracks = tracks[s] # shape (N, 2) + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + # Convert back to RGB for consistent final saving: + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + # Save individual frame + frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") + # Convert to BGR for OpenCV imwrite + frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(frame_path, frame_bgr) + + frame_images.append(img_rgb) + + # Only create and save the grid image if save_grid is True + if save_grid: + # Calculate grid dimensions + num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division + + # Create a grid of images + grid_img = None + for row in range(num_rows): + start_idx = row * frames_per_row + end_idx = min(start_idx + frames_per_row, S) + + # Concatenate this row horizontally + row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) + + # If this row has fewer than frames_per_row images, pad with black + if end_idx - start_idx < frames_per_row: + padding_width = (frames_per_row - (end_idx - start_idx)) * W + padding = np.zeros((H, padding_width, 3), dtype=np.uint8) + row_img = np.concatenate([row_img, padding], axis=1) + + # Add this row to the grid + if grid_img is None: + grid_img = row_img + else: + grid_img = np.concatenate([grid_img, row_img], axis=0) + + out_path = os.path.join(out_dir, "tracks_grid.png") + # Convert back to BGR for OpenCV imwrite + grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, grid_img_bgr) + print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") + + print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") diff --git a/visual_util.py b/visual_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f7b7fe7c75b64470ac8ab9b3288845a083c941 --- /dev/null +++ b/visual_util.py @@ -0,0 +1,457 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import trimesh +import gradio as gr +import numpy as np +import matplotlib +from scipy.spatial.transform import Rotation +import copy +import cv2 +import os +import requests + + +def predictions_to_glb( + predictions, + conf_thres=50.0, + filter_by_frames="all", + mask_black_bg=False, + mask_white_bg=False, + show_cam=True, + mask_sky=False, + target_dir=None, + prediction_mode="Predicted Pointmap", +) -> trimesh.Scene: + """ + Converts VGGT predictions to a 3D scene represented as a GLB file. + + Args: + predictions (dict): Dictionary containing model predictions with keys: + - world_points: 3D point coordinates (S, H, W, 3) + - world_points_conf: Confidence scores (S, H, W) + - images: Input images (S, H, W, 3) + - extrinsic: Camera extrinsic matrices (S, 3, 4) + conf_thres (float): Percentage of low-confidence points to filter out (default: 50.0) + filter_by_frames (str): Frame filter specification (default: "all") + mask_black_bg (bool): Mask out black background pixels (default: False) + mask_white_bg (bool): Mask out white background pixels (default: False) + show_cam (bool): Include camera visualization (default: True) + mask_sky (bool): Apply sky segmentation mask (default: False) + target_dir (str): Output directory for intermediate files (default: None) + prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap") + + Returns: + trimesh.Scene: Processed 3D scene containing point cloud and cameras + + Raises: + ValueError: If input predictions structure is invalid + """ + if not isinstance(predictions, dict): + raise ValueError("predictions must be a dictionary") + + if conf_thres is None: + conf_thres = 10.0 + + print("Building GLB scene") + selected_frame_idx = None + if filter_by_frames != "all" and filter_by_frames != "All": + try: + # Extract the index part before the colon + selected_frame_idx = int(filter_by_frames.split(":")[0]) + except (ValueError, IndexError): + pass + + if "Pointmap" in prediction_mode: + print("Using Pointmap Branch") + if "world_points" in predictions: + pred_world_points = predictions["world_points"] # No batch dimension to remove + pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0])) + else: + print("Warning: world_points not found in predictions, falling back to depth-based points") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0])) + else: + print("Using Depthmap and Camera Branch") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0])) + + # Get images from predictions + images = predictions["images"] + # Use extrinsic matrices instead of pred_extrinsic_list + camera_matrices = predictions["extrinsic"] + + if mask_sky: + if target_dir is not None: + import onnxruntime + + skyseg_session = None + target_dir_images = target_dir + "/images" + image_list = sorted(os.listdir(target_dir_images)) + sky_mask_list = [] + + # Get the shape of pred_world_points_conf to match + S, H, W = ( + pred_world_points_conf.shape + if hasattr(pred_world_points_conf, "shape") + else (len(images), images.shape[1], images.shape[2]) + ) + + # Download skyseg.onnx if it doesn't exist + if not os.path.exists("skyseg.onnx"): + print("Downloading skyseg.onnx...") + download_file_from_url( + "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx" + ) + + for i, image_name in enumerate(image_list): + image_filepath = os.path.join(target_dir_images, image_name) + mask_filepath = os.path.join(target_dir, "sky_masks", image_name) + + # Check if mask already exists + if os.path.exists(mask_filepath): + # Load existing mask + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + # Generate new mask + if skyseg_session is None: + skyseg_session = onnxruntime.InferenceSession("skyseg.onnx") + sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath) + + # Resize mask to match HΓ—W if needed + if sky_mask.shape[0] != H or sky_mask.shape[1] != W: + sky_mask = cv2.resize(sky_mask, (W, H)) + + sky_mask_list.append(sky_mask) + + # Convert list to numpy array with shape SΓ—HΓ—W + sky_mask_array = np.array(sky_mask_list) + + # Apply sky mask to confidence scores + sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32) + pred_world_points_conf = pred_world_points_conf * sky_mask_binary + + if selected_frame_idx is not None: + pred_world_points = pred_world_points[selected_frame_idx][None] + pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None] + images = images[selected_frame_idx][None] + camera_matrices = camera_matrices[selected_frame_idx][None] + + vertices_3d = pred_world_points.reshape(-1, 3) + # Handle different image formats - check if images need transposing + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors_rgb = np.transpose(images, (0, 2, 3, 1)) + else: # Assume already in NHWC format + colors_rgb = images + colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + + conf = pred_world_points_conf.reshape(-1) + # Convert percentage threshold to actual confidence value + if conf_thres == 0.0: + conf_threshold = 0.0 + else: + conf_threshold = np.percentile(conf, conf_thres) + + conf_mask = (conf >= conf_threshold) & (conf > 1e-5) + + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + conf_mask = conf_mask & black_bg_mask + + if mask_white_bg: + # Filter out white background pixels (RGB values close to white) + # Consider pixels white if all RGB values are above 240 + white_bg_mask = ~((colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240)) + conf_mask = conf_mask & white_bg_mask + + vertices_3d = vertices_3d[conf_mask] + colors_rgb = colors_rgb[conf_mask] + + if vertices_3d is None or np.asarray(vertices_3d).size == 0: + vertices_3d = np.array([[1, 0, 0]]) + colors_rgb = np.array([[255, 255, 255]]) + scene_scale = 1 + else: + # Calculate the 5th and 95th percentiles along each axis + lower_percentile = np.percentile(vertices_3d, 5, axis=0) + upper_percentile = np.percentile(vertices_3d, 95, axis=0) + + # Calculate the diagonal length of the percentile bounding box + scene_scale = np.linalg.norm(upper_percentile - lower_percentile) + + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + + # Initialize a 3D scene + scene_3d = trimesh.Scene() + + # Add point cloud data to the scene + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + + scene_3d.add_geometry(point_cloud_data) + + # Prepare 4x4 matrices for camera extrinsics + num_cameras = len(camera_matrices) + extrinsics_matrices = np.zeros((num_cameras, 4, 4)) + extrinsics_matrices[:, :3, :4] = camera_matrices + extrinsics_matrices[:, 3, 3] = 1 + + if show_cam: + # Add camera models to the scene + for i in range(num_cameras): + world_to_camera = extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + rgba_color = colormap(i / num_cameras) + current_color = tuple(int(255 * x) for x in rgba_color[:3]) + + integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) + + # Align scene to the observation of the first camera + scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) + + print("GLB Scene built") + return scene_3d + + +def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float): + """ + Integrates a fake camera mesh into the 3D scene. + + Args: + scene (trimesh.Scene): The 3D scene to add the camera model. + transform (np.ndarray): Transformation matrix for camera positioning. + face_colors (tuple): Color of the camera face. + scene_scale (float): Scale of the scene. + """ + + cam_width = scene_scale * 0.05 + cam_height = scene_scale * 0.1 + + # Create cone shape for camera + rot_45_degree = np.eye(4) + rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix() + rot_45_degree[2, 3] = -cam_height + + opengl_transform = get_opengl_conversion_matrix() + # Combine transformations + complete_transform = transform @ opengl_transform @ rot_45_degree + camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4) + + # Generate mesh for the camera + slight_rotation = np.eye(4) + slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix() + + vertices_combined = np.concatenate( + [ + camera_cone_shape.vertices, + 0.95 * camera_cone_shape.vertices, + transform_points(slight_rotation, camera_cone_shape.vertices), + ] + ) + vertices_transformed = transform_points(complete_transform, vertices_combined) + + mesh_faces = compute_camera_faces(camera_cone_shape) + + # Add the camera mesh to the scene + camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) + camera_mesh.visual.face_colors[:, :3] = face_colors + scene.add_geometry(camera_mesh) + + +def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray) -> trimesh.Scene: + """ + Aligns the 3D scene based on the extrinsics of the first camera. + + Args: + scene_3d (trimesh.Scene): The 3D scene to be aligned. + extrinsics_matrices (np.ndarray): Camera extrinsic matrices. + + Returns: + trimesh.Scene: Aligned 3D scene. + """ + # Set transformations for scene alignment + opengl_conversion_matrix = get_opengl_conversion_matrix() + + # Rotation matrix for alignment (180 degrees around the y-axis) + align_rotation = np.eye(4) + align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix() + + # Apply transformation + initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation + scene_3d.apply_transform(initial_transformation) + return scene_3d + + +def get_opengl_conversion_matrix() -> np.ndarray: + """ + Constructs and returns the OpenGL conversion matrix. + + Returns: + numpy.ndarray: A 4x4 OpenGL conversion matrix. + """ + # Create an identity matrix + matrix = np.identity(4) + + # Flip the y and z axes + matrix[1, 1] = -1 + matrix[2, 2] = -1 + + return matrix + + +def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray: + """ + Applies a 4x4 transformation to a set of points. + + Args: + transformation (np.ndarray): Transformation matrix. + points (np.ndarray): Points to be transformed. + dim (int, optional): Dimension for reshaping the result. + + Returns: + np.ndarray: Transformed points. + """ + points = np.asarray(points) + initial_shape = points.shape[:-1] + dim = dim or points.shape[-1] + + # Apply transformation + transformation = transformation.swapaxes(-1, -2) # Transpose the transformation matrix + points = points @ transformation[..., :-1, :] + transformation[..., -1:, :] + + # Reshape the result + result = points[..., :dim].reshape(*initial_shape, dim) + return result + + +def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray: + """ + Computes the faces for the camera mesh. + + Args: + cone_shape (trimesh.Trimesh): The shape of the camera cone. + + Returns: + np.ndarray: Array of faces for the camera mesh. + """ + # Create pseudo cameras + faces_list = [] + num_vertices_cone = len(cone_shape.vertices) + + for face in cone_shape.faces: + if 0 in face: + continue + v1, v2, v3 = face + v1_offset, v2_offset, v3_offset = face + num_vertices_cone + v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone + + faces_list.extend( + [ + (v1, v2, v2_offset), + (v1, v1_offset, v3), + (v3_offset, v2, v3), + (v1, v2, v2_offset_2), + (v1, v1_offset_2, v3), + (v3_offset_2, v2, v3), + ] + ) + + faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] + return np.array(faces_list) + + +def segment_sky(image_path, onnx_session, mask_filename=None): + """ + Segments sky from an image using an ONNX model. + Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing + + Args: + image_path: Path to input image + onnx_session: ONNX runtime session with loaded model + mask_filename: Path to save the output mask + + Returns: + np.ndarray: Binary mask where 255 indicates non-sky regions + """ + + assert mask_filename is not None + image = cv2.imread(image_path) + + result_map = run_skyseg(onnx_session, [320, 320], image) + # resize the result_map to the original image size + result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0])) + + # Fix: Invert the mask so that 255 = non-sky, 0 = sky + # The model outputs low values for sky, high values for non-sky + output_mask = np.zeros_like(result_map_original) + output_mask[result_map_original < 32] = 255 # Use threshold of 32 + + os.makedirs(os.path.dirname(mask_filename), exist_ok=True) + cv2.imwrite(mask_filename, output_mask) + return output_mask + + +def run_skyseg(onnx_session, input_size, image): + """ + Runs sky segmentation inference using ONNX model. + + Args: + onnx_session: ONNX runtime session + input_size: Target size for model input (width, height) + image: Input image in BGR format + + Returns: + np.ndarray: Segmentation mask + """ + + # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast + temp_image = copy.deepcopy(image) + resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) + x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) + x = np.array(x, dtype=np.float32) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x = (x / 255 - mean) / std + x = x.transpose(2, 0, 1) + x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") + + # Inference + input_name = onnx_session.get_inputs()[0].name + output_name = onnx_session.get_outputs()[0].name + onnx_result = onnx_session.run([output_name], {input_name: x}) + + # Post process + onnx_result = np.array(onnx_result).squeeze() + min_value = np.min(onnx_result) + max_value = np.max(onnx_result) + onnx_result = (onnx_result - min_value) / (max_value - min_value) + onnx_result *= 255 + onnx_result = onnx_result.astype("uint8") + + return onnx_result + + +def download_file_from_url(url, filename): + """Downloads a file from a Hugging Face model repo, handling redirects.""" + try: + # Get the redirect URL + response = requests.get(url, allow_redirects=False) + response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx) + + if response.status_code == 302: # Expecting a redirect + redirect_url = response.headers["Location"] + response = requests.get(redirect_url, stream=True) + response.raise_for_status() + else: + print(f"Unexpected status code: {response.status_code}") + return + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Downloaded {filename} successfully.") + + except requests.exceptions.RequestException as e: + print(f"Error downloading file: {e}")