Ruicheng committed on
Commit 201ab98 · 0 parent(s)

Initial commit for HF

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +37 -0
  2. .gitignore +423 -0
  3. CHANGELOG.md +32 -0
  4. CODE_OF_CONDUCT.md +9 -0
  5. LICENSE +224 -0
  6. README.md +14 -0
  7. SECURITY.md +41 -0
  8. SUPPORT.md +25 -0
  9. app.py +298 -0
  10. assets/overview_simplified.png +3 -0
  11. assets/panorama_pipeline.png +3 -0
  12. baselines/da_v2.py +88 -0
  13. baselines/da_v2_metric.py +99 -0
  14. baselines/metric3d_v2.py +117 -0
  15. baselines/moge.py +83 -0
  16. configs/eval/all_benchmarks.json +78 -0
  17. configs/eval/benchmarks/ddad.json +9 -0
  18. configs/eval/benchmarks/diode.json +9 -0
  19. configs/eval/benchmarks/eth3d.json +10 -0
  20. configs/eval/benchmarks/gso.json +8 -0
  21. configs/eval/benchmarks/hammer.json +10 -0
  22. configs/eval/benchmarks/ibims-1.json +10 -0
  23. configs/eval/benchmarks/kitti.json +9 -0
  24. configs/eval/benchmarks/nyu.json +8 -0
  25. configs/eval/benchmarks/sintel.json +10 -0
  26. configs/eval/benchmarks/spring.json +9 -0
  27. configs/train/v1.json +77 -0
  28. docs/eval.md +77 -0
  29. docs/train.md +181 -0
  30. example_images/01_HouseIndoor.jpg +3 -0
  31. example_images/02_Office.jpg +3 -0
  32. example_images/03_Traffic.jpg +3 -0
  33. example_images/04_BunnyCake.jpg +3 -0
  34. example_images/05_Mountain.jpg +3 -0
  35. example_images/06_MaitreyaBuddha.png +3 -0
  36. example_images/07_Breads.jpg +3 -0
  37. example_images/08_CatGirl.png +3 -0
  38. example_images/09_Restaurant.jpg +3 -0
  39. example_images/10_MedievalVillage.jpg +3 -0
  40. example_images/11_Room.jpg +3 -0
  41. example_images/12_StylizedHouses.jpg +3 -0
  42. example_images/panorama/Braunschweig_Panoram.jpg +3 -0
  43. moge/__init__.py +0 -0
  44. moge/model/__init__.py +18 -0
  45. moge/model/dinov2/__init__.py +6 -0
  46. moge/model/dinov2/hub/__init__.py +4 -0
  47. moge/model/dinov2/hub/backbones.py +156 -0
  48. moge/model/dinov2/hub/utils.py +39 -0
  49. moge/model/dinov2/layers/__init__.py +11 -0
  50. moge/model/dinov2/layers/attention.py +89 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,423 @@
1
+ ## Ignore Visual Studio temporary files, build results, and
2
+ ## files generated by popular Visual Studio add-ons.
3
+ ##
4
+ ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
5
+
6
+ # User-specific files
7
+ *.rsuser
8
+ *.suo
9
+ *.user
10
+ *.userosscache
11
+ *.sln.docstates
12
+
13
+ # User-specific files (MonoDevelop/Xamarin Studio)
14
+ *.userprefs
15
+
16
+ # Mono auto generated files
17
+ mono_crash.*
18
+
19
+ # Build results
20
+ [Dd]ebug/
21
+ [Dd]ebugPublic/
22
+ [Rr]elease/
23
+ [Rr]eleases/
24
+ x64/
25
+ x86/
26
+ [Ww][Ii][Nn]32/
27
+ [Aa][Rr][Mm]/
28
+ [Aa][Rr][Mm]64/
29
+ bld/
30
+ [Bb]in/
31
+ [Oo]bj/
32
+ [Ll]og/
33
+ [Ll]ogs/
34
+
35
+ # Visual Studio 2015/2017 cache/options directory
36
+ .vs/
37
+ # Uncomment if you have tasks that create the project's static files in wwwroot
38
+ #wwwroot/
39
+
40
+ # Visual Studio 2017 auto generated files
41
+ Generated\ Files/
42
+
43
+ # MSTest test Results
44
+ [Tt]est[Rr]esult*/
45
+ [Bb]uild[Ll]og.*
46
+
47
+ # NUnit
48
+ *.VisualState.xml
49
+ TestResult.xml
50
+ nunit-*.xml
51
+
52
+ # Build Results of an ATL Project
53
+ [Dd]ebugPS/
54
+ [Rr]eleasePS/
55
+ dlldata.c
56
+
57
+ # Benchmark Results
58
+ BenchmarkDotNet.Artifacts/
59
+
60
+ # .NET Core
61
+ project.lock.json
62
+ project.fragment.lock.json
63
+ artifacts/
64
+
65
+ # ASP.NET Scaffolding
66
+ ScaffoldingReadMe.txt
67
+
68
+ # StyleCop
69
+ StyleCopReport.xml
70
+
71
+ # Files built by Visual Studio
72
+ *_i.c
73
+ *_p.c
74
+ *_h.h
75
+ *.ilk
76
+ *.meta
77
+ *.obj
78
+ *.iobj
79
+ *.pch
80
+ *.pdb
81
+ *.ipdb
82
+ *.pgc
83
+ *.pgd
84
+ *.rsp
85
+ *.sbr
86
+ *.tlb
87
+ *.tli
88
+ *.tlh
89
+ *.tmp
90
+ *.tmp_proj
91
+ *_wpftmp.csproj
92
+ *.log
93
+ *.tlog
94
+ *.vspscc
95
+ *.vssscc
96
+ .builds
97
+ *.pidb
98
+ *.svclog
99
+ *.scc
100
+
101
+ # Chutzpah Test files
102
+ _Chutzpah*
103
+
104
+ # Visual C++ cache files
105
+ ipch/
106
+ *.aps
107
+ *.ncb
108
+ *.opendb
109
+ *.opensdf
110
+ *.sdf
111
+ *.cachefile
112
+ *.VC.db
113
+ *.VC.VC.opendb
114
+
115
+ # Visual Studio profiler
116
+ *.psess
117
+ *.vsp
118
+ *.vspx
119
+ *.sap
120
+
121
+ # Visual Studio Trace Files
122
+ *.e2e
123
+
124
+ # TFS 2012 Local Workspace
125
+ $tf/
126
+
127
+ # Guidance Automation Toolkit
128
+ *.gpState
129
+
130
+ # ReSharper is a .NET coding add-in
131
+ _ReSharper*/
132
+ *.[Rr]e[Ss]harper
133
+ *.DotSettings.user
134
+
135
+ # TeamCity is a build add-in
136
+ _TeamCity*
137
+
138
+ # DotCover is a Code Coverage Tool
139
+ *.dotCover
140
+
141
+ # AxoCover is a Code Coverage Tool
142
+ .axoCover/*
143
+ !.axoCover/settings.json
144
+
145
+ # Coverlet is a free, cross platform Code Coverage Tool
146
+ coverage*.json
147
+ coverage*.xml
148
+ coverage*.info
149
+
150
+ # Visual Studio code coverage results
151
+ *.coverage
152
+ *.coveragexml
153
+
154
+ # NCrunch
155
+ _NCrunch_*
156
+ .*crunch*.local.xml
157
+ nCrunchTemp_*
158
+
159
+ # MightyMoose
160
+ *.mm.*
161
+ AutoTest.Net/
162
+
163
+ # Web workbench (sass)
164
+ .sass-cache/
165
+
166
+ # Installshield output folder
167
+ [Ee]xpress/
168
+
169
+ # DocProject is a documentation generator add-in
170
+ DocProject/buildhelp/
171
+ DocProject/Help/*.HxT
172
+ DocProject/Help/*.HxC
173
+ DocProject/Help/*.hhc
174
+ DocProject/Help/*.hhk
175
+ DocProject/Help/*.hhp
176
+ DocProject/Help/Html2
177
+ DocProject/Help/html
178
+
179
+ # Click-Once directory
180
+ publish/
181
+
182
+ # Publish Web Output
183
+ *.[Pp]ublish.xml
184
+ *.azurePubxml
185
+ # Note: Comment the next line if you want to checkin your web deploy settings,
186
+ # but database connection strings (with potential passwords) will be unencrypted
187
+ *.pubxml
188
+ *.publishproj
189
+
190
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
191
+ # checkin your Azure Web App publish settings, but sensitive information contained
192
+ # in these scripts will be unencrypted
193
+ PublishScripts/
194
+
195
+ # NuGet Packages
196
+ *.nupkg
197
+ # NuGet Symbol Packages
198
+ *.snupkg
199
+ # The packages folder can be ignored because of Package Restore
200
+ **/[Pp]ackages/*
201
+ # except build/, which is used as an MSBuild target.
202
+ !**/[Pp]ackages/build/
203
+ # Uncomment if necessary however generally it will be regenerated when needed
204
+ #!**/[Pp]ackages/repositories.config
205
+ # NuGet v3's project.json files produces more ignorable files
206
+ *.nuget.props
207
+ *.nuget.targets
208
+
209
+ # Microsoft Azure Build Output
210
+ csx/
211
+ *.build.csdef
212
+
213
+ # Microsoft Azure Emulator
214
+ ecf/
215
+ rcf/
216
+
217
+ # Windows Store app package directories and files
218
+ AppPackages/
219
+ BundleArtifacts/
220
+ Package.StoreAssociation.xml
221
+ _pkginfo.txt
222
+ *.appx
223
+ *.appxbundle
224
+ *.appxupload
225
+
226
+ # Visual Studio cache files
227
+ # files ending in .cache can be ignored
228
+ *.[Cc]ache
229
+ # but keep track of directories ending in .cache
230
+ !?*.[Cc]ache/
231
+
232
+ # Others
233
+ ClientBin/
234
+ ~$*
235
+ *~
236
+ *.dbmdl
237
+ *.dbproj.schemaview
238
+ *.jfm
239
+ *.pfx
240
+ *.publishsettings
241
+ orleans.codegen.cs
242
+
243
+ # Including strong name files can present a security risk
244
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
245
+ #*.snk
246
+
247
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
248
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
249
+ #bower_components/
250
+
251
+ # RIA/Silverlight projects
252
+ Generated_Code/
253
+
254
+ # Backup & report files from converting an old project file
255
+ # to a newer Visual Studio version. Backup files are not needed,
256
+ # because we have git ;-)
257
+ _UpgradeReport_Files/
258
+ Backup*/
259
+ UpgradeLog*.XML
260
+ UpgradeLog*.htm
261
+ ServiceFabricBackup/
262
+ *.rptproj.bak
263
+
264
+ # SQL Server files
265
+ *.mdf
266
+ *.ldf
267
+ *.ndf
268
+
269
+ # Business Intelligence projects
270
+ *.rdl.data
271
+ *.bim.layout
272
+ *.bim_*.settings
273
+ *.rptproj.rsuser
274
+ *- [Bb]ackup.rdl
275
+ *- [Bb]ackup ([0-9]).rdl
276
+ *- [Bb]ackup ([0-9][0-9]).rdl
277
+
278
+ # Microsoft Fakes
279
+ FakesAssemblies/
280
+
281
+ # GhostDoc plugin setting file
282
+ *.GhostDoc.xml
283
+
284
+ # Node.js Tools for Visual Studio
285
+ .ntvs_analysis.dat
286
+ node_modules/
287
+
288
+ # Visual Studio 6 build log
289
+ *.plg
290
+
291
+ # Visual Studio 6 workspace options file
292
+ *.opt
293
+
294
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
295
+ *.vbw
296
+
297
+ # Visual Studio 6 auto-generated project file (contains which files were open etc.)
298
+ *.vbp
299
+
300
+ # Visual Studio 6 workspace and project file (working project files containing files to include in project)
301
+ *.dsw
302
+ *.dsp
303
+
304
+ # Visual Studio 6 technical files
305
+ *.ncb
306
+ *.aps
307
+
308
+ # Visual Studio LightSwitch build output
309
+ **/*.HTMLClient/GeneratedArtifacts
310
+ **/*.DesktopClient/GeneratedArtifacts
311
+ **/*.DesktopClient/ModelManifest.xml
312
+ **/*.Server/GeneratedArtifacts
313
+ **/*.Server/ModelManifest.xml
314
+ _Pvt_Extensions
315
+
316
+ # Paket dependency manager
317
+ .paket/paket.exe
318
+ paket-files/
319
+
320
+ # FAKE - F# Make
321
+ .fake/
322
+
323
+ # CodeRush personal settings
324
+ .cr/personal
325
+
326
+ # Python Tools for Visual Studio (PTVS)
327
+ __pycache__/
328
+ *.pyc
329
+
330
+ # Cake - Uncomment if you are using it
331
+ # tools/**
332
+ # !tools/packages.config
333
+
334
+ # Tabs Studio
335
+ *.tss
336
+
337
+ # Telerik's JustMock configuration file
338
+ *.jmconfig
339
+
340
+ # BizTalk build output
341
+ *.btp.cs
342
+ *.btm.cs
343
+ *.odx.cs
344
+ *.xsd.cs
345
+
346
+ # OpenCover UI analysis results
347
+ OpenCover/
348
+
349
+ # Azure Stream Analytics local run output
350
+ ASALocalRun/
351
+
352
+ # MSBuild Binary and Structured Log
353
+ *.binlog
354
+
355
+ # NVidia Nsight GPU debugger configuration file
356
+ *.nvuser
357
+
358
+ # MFractors (Xamarin productivity tool) working folder
359
+ .mfractor/
360
+
361
+ # Local History for Visual Studio
362
+ .localhistory/
363
+
364
+ # Visual Studio History (VSHistory) files
365
+ .vshistory/
366
+
367
+ # BeatPulse healthcheck temp database
368
+ healthchecksdb
369
+
370
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
371
+ MigrationBackup/
372
+
373
+ # Ionide (cross platform F# VS Code tools) working folder
374
+ .ionide/
375
+
376
+ # Fody - auto-generated XML schema
377
+ FodyWeavers.xsd
378
+
379
+ # VS Code files for those working on multiple tools
380
+ .vscode/*
381
+ !.vscode/settings.json
382
+ !.vscode/tasks.json
383
+ !.vscode/launch.json
384
+ !.vscode/extensions.json
385
+ *.code-workspace
386
+
387
+ # Local History for Visual Studio Code
388
+ .history/
389
+
390
+ # Windows Installer files from build outputs
391
+ *.cab
392
+ *.msi
393
+ *.msix
394
+ *.msm
395
+ *.msp
396
+
397
+ # JetBrains Rider
398
+ *.sln.iml
399
+
400
+ # Python
401
+ *.egg-info/
402
+ /build
403
+
404
+ # MoGe
405
+ /data*
406
+ /download
407
+ /extract
408
+ /debug
409
+ /workspace
410
+ /mlruns
411
+ /infer_output
412
+ /video_output
413
+ /eval_output
414
+ /.blobcache
415
+ /test_images
416
+ /test_videos
417
+ /vis
418
+ /videos
419
+ /blobmnt
420
+ /eval_dump
421
+ /pretrained
422
+ /.gradio
423
+ /tmp
CHANGELOG.md ADDED
@@ -0,0 +1,32 @@
+ ## 2024-11-28
+ ### Added
+ - Supported user-provided camera FOV. See the `--fov_x` option of [scripts/infer.py](scripts/infer.py).
+   - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24).
+ - Added an inference script for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py).
+   - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19).
+
+ ### Fixed
+ - Suppressed unnecessary numpy runtime warnings.
+ - Specified recommended versions of requirements.
+   - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21).
+
+ ### Changed
+ - Moved `app.py` and `infer.py` to [scripts/](scripts/).
+ - Improved edge removal.
+
+ ## 2025-03-18
+ ### Added
+ - Training and evaluation code. See [docs/train.md](docs/train.md) and [docs/eval.md](docs/eval.md).
+ - Supported installation via pip. Thanks to @fabiencastan and @jgoueslard for their commits in [#47](https://github.com/microsoft/MoGe/pull/47).
+ - Supported command-line usage when installed.
+
+ ### Changed
+ - Moved `scripts/` into `moge/` for package installation and command-line usage.
+ - Renamed `moge.model.moge_model` to `moge.model.v1` for version management.
+   You can now import the model class via `from moge.model.v1 import MoGeModel` or `from moge.model import import_model_class_by_version; MoGeModel = import_model_class_by_version('v1')`.
+ - Exposed the `num_tokens` parameter of the MoGe model.
+
+ ## 2025-06-10
+ ### Added
+ - Released MoGe-2.
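The 2025-03-18 entry quotes two ways to import the model class after the rename. Below is a minimal sketch of both, assuming the package is pip-installed; the checkpoint name `Ruicheng/moge-vitl` is the v1 default listed in this commit's `app.py`, not something stated in the changelog itself.

```python
import torch

# Style 1 from the changelog: import the versioned class directly.
from moge.model.v1 import MoGeModel

# Style 2 from the changelog: resolve the class from a version string.
from moge.model import import_model_class_by_version
MoGeModelV1 = import_model_class_by_version('v1')  # expected to resolve to the same class

# Loading pretrained weights; "Ruicheng/moge-vitl" is taken from app.py's v1 default (assumption).
model = MoGeModel.from_pretrained('Ruicheng/moge-vitl').eval()
if torch.cuda.is_available():
    model = model.cuda()
```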
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
+ # Microsoft Open Source Code of Conduct
+
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+ Resources:
+
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+ - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
LICENSE ADDED
@@ -0,0 +1,224 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
22
+
23
+
24
+ Apache License
25
+ Version 2.0, January 2004
26
+ http://www.apache.org/licenses/
27
+
28
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
29
+
30
+ 1. Definitions.
31
+
32
+ "License" shall mean the terms and conditions for use, reproduction,
33
+ and distribution as defined by Sections 1 through 9 of this document.
34
+
35
+ "Licensor" shall mean the copyright owner or entity authorized by
36
+ the copyright owner that is granting the License.
37
+
38
+ "Legal Entity" shall mean the union of the acting entity and all
39
+ other entities that control, are controlled by, or are under common
40
+ control with that entity. For the purposes of this definition,
41
+ "control" means (i) the power, direct or indirect, to cause the
42
+ direction or management of such entity, whether by contract or
43
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
44
+ outstanding shares, or (iii) beneficial ownership of such entity.
45
+
46
+ "You" (or "Your") shall mean an individual or Legal Entity
47
+ exercising permissions granted by this License.
48
+
49
+ "Source" form shall mean the preferred form for making modifications,
50
+ including but not limited to software source code, documentation
51
+ source, and configuration files.
52
+
53
+ "Object" form shall mean any form resulting from mechanical
54
+ transformation or translation of a Source form, including but
55
+ not limited to compiled object code, generated documentation,
56
+ and conversions to other media types.
57
+
58
+ "Work" shall mean the work of authorship, whether in Source or
59
+ Object form, made available under the License, as indicated by a
60
+ copyright notice that is included in or attached to the work
61
+ (an example is provided in the Appendix below).
62
+
63
+ "Derivative Works" shall mean any work, whether in Source or Object
64
+ form, that is based on (or derived from) the Work and for which the
65
+ editorial revisions, annotations, elaborations, or other modifications
66
+ represent, as a whole, an original work of authorship. For the purposes
67
+ of this License, Derivative Works shall not include works that remain
68
+ separable from, or merely link (or bind by name) to the interfaces of,
69
+ the Work and Derivative Works thereof.
70
+
71
+ "Contribution" shall mean any work of authorship, including
72
+ the original version of the Work and any modifications or additions
73
+ to that Work or Derivative Works thereof, that is intentionally
74
+ submitted to Licensor for inclusion in the Work by the copyright owner
75
+ or by an individual or Legal Entity authorized to submit on behalf of
76
+ the copyright owner. For the purposes of this definition, "submitted"
77
+ means any form of electronic, verbal, or written communication sent
78
+ to the Licensor or its representatives, including but not limited to
79
+ communication on electronic mailing lists, source code control systems,
80
+ and issue tracking systems that are managed by, or on behalf of, the
81
+ Licensor for the purpose of discussing and improving the Work, but
82
+ excluding communication that is conspicuously marked or otherwise
83
+ designated in writing by the copyright owner as "Not a Contribution."
84
+
85
+ "Contributor" shall mean Licensor and any individual or Legal Entity
86
+ on behalf of whom a Contribution has been received by Licensor and
87
+ subsequently incorporated within the Work.
88
+
89
+ 2. Grant of Copyright License. Subject to the terms and conditions of
90
+ this License, each Contributor hereby grants to You a perpetual,
91
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
92
+ copyright license to reproduce, prepare Derivative Works of,
93
+ publicly display, publicly perform, sublicense, and distribute the
94
+ Work and such Derivative Works in Source or Object form.
95
+
96
+ 3. Grant of Patent License. Subject to the terms and conditions of
97
+ this License, each Contributor hereby grants to You a perpetual,
98
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
99
+ (except as stated in this section) patent license to make, have made,
100
+ use, offer to sell, sell, import, and otherwise transfer the Work,
101
+ where such license applies only to those patent claims licensable
102
+ by such Contributor that are necessarily infringed by their
103
+ Contribution(s) alone or by combination of their Contribution(s)
104
+ with the Work to which such Contribution(s) was submitted. If You
105
+ institute patent litigation against any entity (including a
106
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
107
+ or a Contribution incorporated within the Work constitutes direct
108
+ or contributory patent infringement, then any patent licenses
109
+ granted to You under this License for that Work shall terminate
110
+ as of the date such litigation is filed.
111
+
112
+ 4. Redistribution. You may reproduce and distribute copies of the
113
+ Work or Derivative Works thereof in any medium, with or without
114
+ modifications, and in Source or Object form, provided that You
115
+ meet the following conditions:
116
+
117
+ (a) You must give any other recipients of the Work or
118
+ Derivative Works a copy of this License; and
119
+
120
+ (b) You must cause any modified files to carry prominent notices
121
+ stating that You changed the files; and
122
+
123
+ (c) You must retain, in the Source form of any Derivative Works
124
+ that You distribute, all copyright, patent, trademark, and
125
+ attribution notices from the Source form of the Work,
126
+ excluding those notices that do not pertain to any part of
127
+ the Derivative Works; and
128
+
129
+ (d) If the Work includes a "NOTICE" text file as part of its
130
+ distribution, then any Derivative Works that You distribute must
131
+ include a readable copy of the attribution notices contained
132
+ within such NOTICE file, excluding those notices that do not
133
+ pertain to any part of the Derivative Works, in at least one
134
+ of the following places: within a NOTICE text file distributed
135
+ as part of the Derivative Works; within the Source form or
136
+ documentation, if provided along with the Derivative Works; or,
137
+ within a display generated by the Derivative Works, if and
138
+ wherever such third-party notices normally appear. The contents
139
+ of the NOTICE file are for informational purposes only and
140
+ do not modify the License. You may add Your own attribution
141
+ notices within Derivative Works that You distribute, alongside
142
+ or as an addendum to the NOTICE text from the Work, provided
143
+ that such additional attribution notices cannot be construed
144
+ as modifying the License.
145
+
146
+ You may add Your own copyright statement to Your modifications and
147
+ may provide additional or different license terms and conditions
148
+ for use, reproduction, or distribution of Your modifications, or
149
+ for any such Derivative Works as a whole, provided Your use,
150
+ reproduction, and distribution of the Work otherwise complies with
151
+ the conditions stated in this License.
152
+
153
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
154
+ any Contribution intentionally submitted for inclusion in the Work
155
+ by You to the Licensor shall be under the terms and conditions of
156
+ this License, without any additional terms or conditions.
157
+ Notwithstanding the above, nothing herein shall supersede or modify
158
+ the terms of any separate license agreement you may have executed
159
+ with Licensor regarding such Contributions.
160
+
161
+ 6. Trademarks. This License does not grant permission to use the trade
162
+ names, trademarks, service marks, or product names of the Licensor,
163
+ except as required for reasonable and customary use in describing the
164
+ origin of the Work and reproducing the content of the NOTICE file.
165
+
166
+ 7. Disclaimer of Warranty. Unless required by applicable law or
167
+ agreed to in writing, Licensor provides the Work (and each
168
+ Contributor provides its Contributions) on an "AS IS" BASIS,
169
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
170
+ implied, including, without limitation, any warranties or conditions
171
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
172
+ PARTICULAR PURPOSE. You are solely responsible for determining the
173
+ appropriateness of using or redistributing the Work and assume any
174
+ risks associated with Your exercise of permissions under this License.
175
+
176
+ 8. Limitation of Liability. In no event and under no legal theory,
177
+ whether in tort (including negligence), contract, or otherwise,
178
+ unless required by applicable law (such as deliberate and grossly
179
+ negligent acts) or agreed to in writing, shall any Contributor be
180
+ liable to You for damages, including any direct, indirect, special,
181
+ incidental, or consequential damages of any character arising as a
182
+ result of this License or out of the use or inability to use the
183
+ Work (including but not limited to damages for loss of goodwill,
184
+ work stoppage, computer failure or malfunction, or any and all
185
+ other commercial damages or losses), even if such Contributor
186
+ has been advised of the possibility of such damages.
187
+
188
+ 9. Accepting Warranty or Additional Liability. While redistributing
189
+ the Work or Derivative Works thereof, You may choose to offer,
190
+ and charge a fee for, acceptance of support, warranty, indemnity,
191
+ or other liability obligations and/or rights consistent with this
192
+ License. However, in accepting such obligations, You may act only
193
+ on Your own behalf and on Your sole responsibility, not on behalf
194
+ of any other Contributor, and only if You agree to indemnify,
195
+ defend, and hold each Contributor harmless for any liability
196
+ incurred by, or claims asserted against, such Contributor by reason
197
+ of your accepting any such warranty or additional liability.
198
+
199
+ END OF TERMS AND CONDITIONS
200
+
201
+ APPENDIX: How to apply the Apache License to your work.
202
+
203
+ To apply the Apache License to your work, attach the following
204
+ boilerplate notice, with the fields enclosed by brackets "[]"
205
+ replaced with your own identifying information. (Don't include
206
+ the brackets!) The text should be enclosed in the appropriate
207
+ comment syntax for the file format. We also recommend that a
208
+ file or class name and description of purpose be included on the
209
+ same "printed page" as the copyright notice for easier
210
+ identification within third-party archives.
211
+
212
+ Copyright [yyyy] [name of copyright owner]
213
+
214
+ Licensed under the Apache License, Version 2.0 (the "License");
215
+ you may not use this file except in compliance with the License.
216
+ You may obtain a copy of the License at
217
+
218
+ http://www.apache.org/licenses/LICENSE-2.0
219
+
220
+ Unless required by applicable law or agreed to in writing, software
221
+ distributed under the License is distributed on an "AS IS" BASIS,
222
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
223
+ See the License for the specific language governing permissions and
224
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: MoGe 2
+ emoji: 🚀
+ colorFrom: indigo
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.33.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: Monocular metric-scale geometry estimation
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SECURITY.md ADDED
@@ -0,0 +1,41 @@
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
+
+ ## Security
+
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
+
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
+
+ ## Reporting Security Issues
+
+ **Please do not report security vulnerabilities through public GitHub issues.**
+
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
+
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
+
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
+
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+ * Full paths of source file(s) related to the manifestation of the issue
+ * The location of the affected source code (tag/branch/commit or direct URL)
+ * Any special configuration required to reproduce the issue
+ * Step-by-step instructions to reproduce the issue
+ * Proof-of-concept or exploit code (if possible)
+ * Impact of the issue, including how an attacker might exploit the issue
+
+ This information will help us triage your report more quickly.
+
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
+
+ ## Preferred Languages
+
+ We prefer all communications to be in English.
+
+ ## Policy
+
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
+
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
SUPPORT.md ADDED
@@ -0,0 +1,25 @@
+ # TODO: The maintainer of this repo has not yet edited this file
+
+ **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
+
+ - **No CSS support:** Fill out this template with information about how to file issues and get help.
+ - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
+ - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
+
+ *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
+
+ # Support
+
+ ## How to file issues and get help
+
+ This project uses GitHub Issues to track bugs and feature requests. Please search the existing
+ issues before filing new issues to avoid duplicates. For new issues, file your bug or
+ feature request as a new Issue.
+
+ For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
+ FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
+ CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
+
+ ## Microsoft Support Policy
+
+ Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
app.py ADDED
@@ -0,0 +1,298 @@
+ import os
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+ import sys
+ from pathlib import Path
+ import time
+ import uuid
+ import tempfile
+ import itertools
+ from typing import *
+ import atexit
+ from concurrent.futures import ThreadPoolExecutor
+ import shutil
+
+ import click
+
+
+ @click.command(help='Web demo')
+ @click.option('--share', is_flag=True, help='Whether to run the app in shared mode.')
+ @click.option('--pretrained', 'pretrained_model_name_or_path', default=None, help='The name or path of the pre-trained model.')
+ @click.option('--version', 'model_version', default='v2', help='The version of the model.')
+ def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool = True):
+     print("Import modules...")
+     # Lazy import
+     import cv2
+     import torch
+     import numpy as np
+     import trimesh
+     import trimesh.visual
+     from PIL import Image
+     import gradio as gr
+     try:
+         import spaces   # This is for deployment at huggingface.co/spaces
+         HUGGINGFACE_SPACES_INSTALLED = True
+     except ImportError:
+         HUGGINGFACE_SPACES_INSTALLED = False
+
+     import utils3d
+     from moge.utils.io import write_normal
+     from moge.utils.vis import colorize_depth, colorize_normal
+     from moge.model import import_model_class_by_version
+     from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
+     from moge.utils.tools import timeit
+
+     print("Load model...")
+     if pretrained_model_name_or_path is None:
+         DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
+             "v1": "Ruicheng/moge-vitl",
+             "v2": "Ruicheng/moge-2-vitl-normal",
+         }
+         pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
+     model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval()
+     if use_fp16:
+         model.half()
+     thread_pool_executor = ThreadPoolExecutor(max_workers=1)
+
+     # Remove generated files after a delay so temporary outputs do not accumulate.
+     def delete_later(path: Union[str, os.PathLike], delay: int = 300):
+         def _delete():
+             try:
+                 os.remove(path)
+             except FileNotFoundError:
+                 pass
+         def _wait_and_delete():
+             time.sleep(delay)
+             _delete()
+         thread_pool_executor.submit(_wait_and_delete)
+         atexit.register(_delete)
+
+     # Inference on GPU.
+     @(spaces.GPU if HUGGINGFACE_SPACES_INSTALLED else lambda x: x)
+     def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
+         image_tensor = torch.tensor(image, dtype=torch.float32 if not use_fp16 else torch.float16, device=torch.device('cuda')).permute(2, 0, 1) / 255
+         output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=use_fp16)
+         output = {k: v.cpu().numpy() for k, v in output.items()}
+         return output
+
+     # Full inference pipeline
+     def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
+         larger_size = max(image.shape[:2])
+         if larger_size > max_size:
+             scale = max_size / larger_size
+             image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
+
+         height, width = image.shape[:2]
+
+         resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 18}.get(resolution_level, 9)
+         output = run_with_gpu(image, resolution_level_int, apply_mask)
+
+         points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None)
+
+         if remove_edge:
+             mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=0.04)
+         else:
+             mask_cleaned = mask
+
+         results = {
+             **output,
+             'mask_cleaned': mask_cleaned,
+             'image': image
+         }
+
+         # depth & normal visualization
+         depth_vis = colorize_depth(depth)
+         if normal is not None:
+             normal_vis = colorize_normal(normal)
+         else:
+             normal_vis = gr.update(label="Normal map (not available for this model)")
+
+         # mesh & pointcloud
+         if normal is None:
+             faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
+                 points,
+                 image.astype(np.float32) / 255,
+                 utils3d.numpy.image_uv(width=width, height=height),
+                 mask=mask_cleaned,
+                 tri=True
+             )
+             vertex_normals = None
+         else:
+             faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
+                 points,
+                 image.astype(np.float32) / 255,
+                 utils3d.numpy.image_uv(width=width, height=height),
+                 normal,
+                 mask=mask_cleaned,
+                 tri=True
+             )
+         # Flip the y/z axes of the vertices (and the v coordinate of the UVs) for the exported viewers' convention.
+         vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
+         vertex_uvs = vertex_uvs * np.array([1, -1], dtype=np.float32) + np.array([0, 1], dtype=np.float32)
+         if vertex_normals is not None:
+             vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
+
+         tempdir = Path(tempfile.gettempdir(), 'moge')
+         tempdir.mkdir(exist_ok=True)
+         output_path = Path(tempdir, request.session_hash)
+         shutil.rmtree(output_path, ignore_errors=True)
+         output_path.mkdir(exist_ok=True, parents=True)
+         trimesh.Trimesh(
+             vertices=vertices,
+             faces=faces,
+             vertex_normals=vertex_normals,
+             visual = trimesh.visual.texture.TextureVisuals(
+                 uv=vertex_uvs,
+                 material=trimesh.visual.material.PBRMaterial(
+                     baseColorTexture=Image.fromarray(image),
+                     metallicFactor=0.5,
+                     roughnessFactor=1.0
+                 )
+             ),
+             process=False
+         ).export(output_path / 'mesh.glb')
+         pointcloud = trimesh.PointCloud(
+             vertices=vertices,
+             colors=vertex_colors,
+         )
+         pointcloud.vertex_normals = vertex_normals
+         pointcloud.export(output_path / 'pointcloud.ply', vertex_normal=True)
+         trimesh.PointCloud(
+             vertices=vertices,
+             colors=vertex_colors,
+         ).export(output_path / 'pointcloud.glb', include_normals=True)
+         cv2.imwrite(str(output_path / 'mask.png'), mask.astype(np.uint8) * 255)
+         cv2.imwrite(str(output_path / 'depth.exr'), depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+         cv2.imwrite(str(output_path / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+         if normal is not None:
+             cv2.imwrite(str(output_path / 'normal.exr'), cv2.cvtColor(normal.astype(np.float32) * np.array([1, -1, -1], dtype=np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+
+         files = ['mesh.glb', 'pointcloud.ply', 'depth.exr', 'points.exr', 'mask.png']
+         if normal is not None:
+             files.append('normal.exr')
+
+         for f in files:
+             delete_later(output_path / f)
+
+         # FOV
+         intrinsics = results['intrinsics']
+         fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
+         fov_x, fov_y = np.rad2deg([fov_x, fov_y])
+
+         # messages
+         viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.'
+         if resolution_level != 'Ultra':
+             depth_message = f'**Note:** Want a sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.'
+         else:
+             depth_message = ""
+
+         return (
+             results,
+             depth_vis,
+             normal_vis,
+             output_path / 'pointcloud.glb',
+             [(output_path / f).as_posix() for f in files if (output_path / f).exists()],
+             f'- **Horizontal FOV: {fov_x:.1f}°**. \n - **Vertical FOV: {fov_y:.1f}°**',
+             viewer_message,
+             depth_message
+         )
+
+     def reset_measure(results: Dict[str, np.ndarray]):
+         return [results['image'], [], ""]
+
+
+     def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int]], event: gr.SelectData):
+         point2d = event.index[0], event.index[1]
+         measure_points.append(point2d)
+
+         image = results['image'].copy()
+         for p in measure_points:
+             image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
+
+         depth_text = ""
+         for i, p in enumerate(measure_points):
+             d = results['depth'][p[1], p[0]]
+             depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
+
+         if len(measure_points) == 2:
+             point1, point2 = measure_points
+             image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
+             distance = np.linalg.norm(results['points'][point1[1], point1[0]] - results['points'][point2[1], point2[0]])
+             measure_points = []
+
+             distance_text = f"- **Distance: {distance:.2f}m**"
+
+             text = depth_text + distance_text
+             return [image, measure_points, text]
+         else:
+             return [image, measure_points, depth_text]
+
+     print("Create Gradio app...")
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown(
+             f'''
+             <div align="center">
+             <h1> Turn a 2D image into 3D with MoGe <a title="Github" href="https://github.com/microsoft/MoGe" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/microsoft/MoGe?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> </a> </h1>
+             </div>
+             ''')
+         results = gr.State(value=None)
+         measure_points = gr.State(value=[])
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="numpy", image_mode="RGB", label="Input Image")
+                 with gr.Accordion(label="Settings", open=False):
+                     max_size_input = gr.Number(value=800, label="Maximum Image Size", precision=0, minimum=256, maximum=2048)
+                     resolution_level = gr.Dropdown(['Low', 'Medium', 'High', 'Ultra'], label="Inference Resolution Level", value='High')
+                     apply_mask = gr.Checkbox(value=True, label="Apply mask")
+                     remove_edges = gr.Checkbox(value=True, label="Remove edges")
+                 submit_btn = gr.Button("Submit", variant='primary')
+
+             with gr.Column():
+                 with gr.Tabs():
+                     with gr.Tab("3D View"):
+                         viewer_message = gr.Markdown("")
+                         model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1.0, 1.0, 1.0, 1.0], height="60vh")
+                         fov = gr.Markdown()
+                     with gr.Tab("Depth"):
+                         depth_message = gr.Markdown("")
+                         depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format='png', interactive=False)
+                     with gr.Tab("Normal", interactive=hasattr(model, 'normal_head')):
+                         normal_map = gr.Image(type="numpy", label="Normal Map", format='png', interactive=False)
+                     with gr.Tab("Measure", interactive=hasattr(model, 'scale_head')):
+                         gr.Markdown("### Click on the image to measure the distance between two points. \n"
+                                     "**Note:** Metric scale is most reliable for typical indoor or street scenes, and may degrade for contents unfamiliar to the model (e.g., stylized or close-up images).")
+                         measure_image = gr.Image(type="numpy", show_label=False, format='webp', interactive=False, sources=[])
+                         measure_text = gr.Markdown("")
+                     with gr.Tab("Download"):
+                         files = gr.File(type='filepath', label="Output Files")
+
+         if Path('example_images').exists():
+             example_image_paths = sorted(list(itertools.chain(*[Path('example_images').glob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']])))
+             examples = gr.Examples(
+                 examples = example_image_paths,
+                 inputs=input_image,
+                 label="Examples"
+             )
+
+         submit_btn.click(
+             fn=lambda: [None, None, None, None, None, "", "", ""],
+             outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
+         ).then(
+             fn=run,
+             inputs=[input_image, max_size_input, resolution_level, apply_mask, remove_edges],
+             outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
+         ).then(
+             fn=reset_measure,
+             inputs=[results],
+             outputs=[measure_image, measure_points, measure_text]
+         )
+
+         measure_image.select(
+             fn=measure,
+             inputs=[results, measure_points],
+             outputs=[measure_image, measure_points, measure_text]
+         )
+
+     demo.launch(share=share)
+
+
+ if __name__ == '__main__':
+     main()
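`app.py` wraps the model call in a Gradio UI, but the core path it uses (`import_model_class_by_version(...).from_pretrained(...)` followed by `model.infer(...)`) can also be run headlessly. Below is a minimal sketch under the same assumptions the app makes: a CUDA GPU and the default v2 checkpoint `Ruicheng/moge-2-vitl-normal`; the example image path is one of the files added in this commit.

```python
import cv2
import torch

from moge.model import import_model_class_by_version

# Same default checkpoint and fp16 setup as app.py.
model = import_model_class_by_version('v2').from_pretrained('Ruicheng/moge-2-vitl-normal').cuda().eval().half()

# Load an RGB image and convert it to a 3xHxW float16 tensor in [0, 1].
image = cv2.cvtColor(cv2.imread('example_images/01_HouseIndoor.jpg'), cv2.COLOR_BGR2RGB)
image_tensor = torch.tensor(image, dtype=torch.float16, device='cuda').permute(2, 0, 1) / 255

# resolution_level=9 corresponds to the app's 'High' preset.
with torch.inference_mode():
    output = model.infer(image_tensor, apply_mask=True, resolution_level=9, use_fp16=True)

# app.py consumes these keys: 'points', 'depth', 'mask', 'intrinsics' and,
# for normal-capable checkpoints, 'normal'.
depth = output['depth'].float().cpu().numpy()
```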
assets/overview_simplified.png ADDED

Git LFS Details

  • SHA256: 7025a671e863bddbc22e79dc3e2eca8b7aeaf35fe93f6ef7f2b18f4fc9e093e6
  • Pointer size: 131 Bytes
  • Size of remote file: 414 kB
assets/panorama_pipeline.png ADDED

Git LFS Details

  • SHA256: ed28c5309162bddda016ca600307ecc73f7e6415f9eaaefb9f6fffadf6951aaa
  • Pointer size: 131 Bytes
  • Size of remote file: 738 kB
baselines/da_v2.py ADDED
@@ -0,0 +1,88 @@
+ # Reference: https://github.com/DepthAnything/Depth-Anything-V2
+ import os
+ import sys
+ from typing import *
+ from pathlib import Path
+
+ import click
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+
+ from moge.test.baseline import MGEBaselineInterface
+
+
+ class Baseline(MGEBaselineInterface):
+     def __init__(self, repo_path: str, backbone: str, num_tokens: int, device: Union[torch.device, str]):
+         # Create from repo
+         repo_path = os.path.abspath(repo_path)
+         if repo_path not in sys.path:
+             sys.path.append(repo_path)
+         if not Path(repo_path).exists():
+             raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.')
+         from depth_anything_v2.dpt import DepthAnythingV2
+
+         device = torch.device(device)
+
+         # Instantiate model
+         model = DepthAnythingV2(encoder=backbone, features=256, out_channels=[256, 512, 1024, 1024])
+
+         # Load checkpoint
+         checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_{backbone}.pth')
+         if not os.path.exists(checkpoint_path):
+             raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.')
+         checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+         model.load_state_dict(checkpoint)
+
+         model.to(device).eval()
+         self.model = model
+         self.num_tokens = num_tokens
+         self.device = device
+
+     @click.command()
+     @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.')
+     @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Encoder architecture.')
+     @click.option('--num_tokens', type=int, default=None, help='Number of tokens to use for the input image.')
+     @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
+     @staticmethod
+     def load(repo_path: str, backbone, num_tokens: int, device: torch.device = 'cuda'):
+         return Baseline(repo_path, backbone, num_tokens, device)
+
+     @torch.inference_mode()
+     def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+         original_height, original_width = image.shape[-2:]
+
+         assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input"
+
+         if image.ndim == 3:
+             image = image.unsqueeze(0)
+             omit_batch_dim = True
+         else:
+             omit_batch_dim = False
+
+         if self.num_tokens is None:
+             resize_factor = 518 / min(original_height, original_width)
+             expected_width = round(original_width * resize_factor / 14) * 14
+             expected_height = round(original_height * resize_factor / 14) * 14
+         else:
+             aspect_ratio = original_width / original_height
+             tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5)
+             tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5)
+             expected_width = tokens_cols * 14
+             expected_height = tokens_rows * 14
+         image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
+
+         image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+         disparity = self.model(image)
+
+         disparity = F.interpolate(disparity[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0]
+
+         if omit_batch_dim:
+             disparity = disparity.squeeze(0)
+
+         return {
+             'disparity_affine_invariant': disparity
+         }
+
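For context, a hedged usage sketch of the wrapper above (not part of the commit): it assumes Depth-Anything-V2 is cloned at `../Depth-Anything-V2` with `checkpoints/depth_anything_v2_vitl.pth` downloaded, as the error messages in `__init__` describe, and that `baselines` is importable from the repository root (the import path is an assumption).

```python
import torch

from baselines.da_v2 import Baseline  # import path is an assumption

# Construct the wrapper directly; the CLI options shown above map to these arguments.
baseline = Baseline(repo_path='../Depth-Anything-V2', backbone='vitl', num_tokens=None, device='cuda')

# A 3xHxW float image in [0, 1]; with num_tokens=None the short side is resized to ~518
# and both sides are snapped to multiples of the ViT patch size (14).
image = torch.rand(3, 480, 640, device='cuda')
prediction = baseline.infer(image)
disparity = prediction['disparity_affine_invariant']  # HxW affine-invariant disparity
```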
baselines/da_v2_metric.py ADDED
@@ -0,0 +1,99 @@
+ # Reference https://github.com/DepthAnything/Depth-Anything-V2/metric_depth
+ import os
+ import sys
+ from typing import *
+ from pathlib import Path
+
+ import click
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+ import cv2
+
+ from moge.test.baseline import MGEBaselineInterface
+
+
+ class Baseline(MGEBaselineInterface):
+
+     def __init__(self, repo_path: str, backbone: str, domain: str, num_tokens: int, device: str):
+         device = torch.device(device)
+         repo_path = os.path.abspath(repo_path)
+         if not Path(repo_path).exists():
+             raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.')
+         sys.path.append(os.path.join(repo_path, 'metric_depth'))
+         from depth_anything_v2.dpt import DepthAnythingV2
+
+         model_configs = {
+             'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+             'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+             'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
+         }
+
+         if domain == 'indoor':
+             dataset = 'hypersim'
+             max_depth = 20
+         elif domain == 'outdoor':
+             dataset = 'vkitti'
+             max_depth = 80
+         else:
+             raise ValueError(f"Invalid domain: {domain}")
+
+         model = DepthAnythingV2(**model_configs[backbone], max_depth=max_depth)
+         checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_metric_{dataset}_{backbone}.pth')
+         if not os.path.exists(checkpoint_path):
+             raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.')
+         model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
+         model.eval().to(device)
+
+         self.model = model
+         self.num_tokens = num_tokens
+         self.device = device
+
+     @click.command()
+     @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.')
+     @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Backbone architecture.')
+     @click.option('--domain', type=click.Choice(['indoor', 'outdoor']), help='Domain of the dataset.')
+     @click.option('--num_tokens', type=int, default=None, help='Number of tokens for the ViT model')
+     @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
+     @staticmethod
+     def load(repo_path: str, backbone: str, domain: str, num_tokens: int, device: str):
+         return Baseline(repo_path, backbone, domain, num_tokens, device)
+
+     @torch.inference_mode()
+     def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+         original_height, original_width = image.shape[-2:]
+
+         assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input"
+
+         if image.ndim == 3:
+             image = image.unsqueeze(0)
+             omit_batch_dim = True
+         else:
+             omit_batch_dim = False
+
+         if self.num_tokens is None:
+             resize_factor = 518 / min(original_height, original_width)
+             expected_width = round(original_width * resize_factor / 14) * 14
+             expected_height = round(original_height * resize_factor / 14) * 14
+         else:
+             aspect_ratio = original_width / original_height
+             tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5)
+             tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5)
+             expected_width = tokens_cols * 14
+             expected_height = tokens_rows * 14
+         image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
+
+         image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+         depth = self.model(image)
+
+         depth = F.interpolate(depth[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0]
+
+         if omit_batch_dim:
+             depth = depth.squeeze(0)
+
+         return {
+             'depth_metric': depth
+         }
+
baselines/metric3d_v2.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reference: https://github.com/YvanYin/Metric3D
2
+ import os
3
+ import sys
4
+ from typing import *
5
+
6
+ import click
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import cv2
10
+
11
+ from moge.test.baseline import MGEBaselineInterface
12
+
13
+
14
+ class Baseline(MGEBaselineInterface):
15
+ def __init__(self, backbone: Literal['vits', 'vitl', 'vitg'], device):
16
+ backbone_map = {
17
+ 'vits': 'metric3d_vit_small',
18
+ 'vitl': 'metric3d_vit_large',
19
+ 'vitg': 'metric3d_vit_giant2'
20
+ }
21
+
22
+ device = torch.device(device)
23
+ model = torch.hub.load('yvanyin/metric3d', backbone_map[backbone], pretrain=True)
24
+ model.to(device).eval()
25
+
26
+ self.model = model
27
+ self.device = device
28
+
29
+ @click.command()
30
+ @click.option('--backbone', type=click.Choice(['vits', 'vitl', 'vitg']), default='vitl', help='Encoder architecture.')
31
+ @click.option('--device', type=str, default='cuda', help='Device to use.')
32
+ @staticmethod
33
+ def load(backbone: str = 'vitl', device: torch.device = 'cuda'):
34
+ return Baseline(backbone, device)
35
+
36
+ @torch.inference_mode()
37
+ def inference_one_image(self, image: torch.Tensor, intrinsics: torch.Tensor = None):
38
+ # Reference: https://github.com/YvanYin/Metric3D/blob/main/mono/utils/do_test.py
39
+
40
+ # rgb_origin: RGB image in HWC layout, value range 0-255 (float)
41
+ rgb_origin = image.cpu().numpy().transpose((1, 2, 0)) * 255
42
+
43
+ # keep ratio resize
44
+ input_size = (616, 1064) # for vit model
45
+ h, w = rgb_origin.shape[:2]
46
+ scale = min(input_size[0] / h, input_size[1] / w)
47
+ rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
48
+ if intrinsics is not None:
49
+ focal = intrinsics[0, 0] * int(w * scale)
50
+
51
+ # padding to input_size
52
+ padding = [123.675, 116.28, 103.53]
53
+ h, w = rgb.shape[:2]
54
+ pad_h = input_size[0] - h
55
+ pad_w = input_size[1] - w
56
+ pad_h_half = pad_h // 2
57
+ pad_w_half = pad_w // 2
58
+ rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding)
59
+ pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
60
+
61
+ # normalize rgb
62
+ mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
63
+ std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
64
+ rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
65
+ rgb = torch.div((rgb - mean), std)
66
+ rgb = rgb[None, :, :, :].cuda()
67
+
68
+ # inference
69
+ pred_depth, confidence, output_dict = self.model.inference({'input': rgb})
70
+
71
+ # un pad
72
+ pred_depth = pred_depth.squeeze()
73
+ pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]
74
+ pred_depth = pred_depth.clamp_min(0.5) # clamp to 0.5 m, since Metric3D can yield very small depth values that would otherwise break the scale-shift alignment
75
+
76
+ # upsample to original size
77
+ pred_depth = F.interpolate(pred_depth[None, None, :, :], image.shape[-2:], mode='bilinear').squeeze()
78
+
79
+ if intrinsics is not None:
80
+ # de-canonical transform
81
+ canonical_to_real_scale = focal / 1000.0 # 1000.0 is the focal length of canonical camera
82
+ pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric
83
+ pred_depth = torch.clamp(pred_depth, 0, 300)
84
+
85
+ pred_normal, normal_confidence = output_dict['prediction_normal'].split([3, 1], dim=1) # see https://arxiv.org/abs/2109.09881 for details
86
+
87
+ # unpad the normal map
88
+ pred_normal = pred_normal.squeeze(0)
89
+ pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]]
90
+
91
+ # resize the normal map to the original resolution and re-normalize
92
+ pred_normal = F.interpolate(pred_normal[None, :, :, :], image.shape[-2:], mode='bilinear').squeeze(0)
93
+ pred_normal = F.normalize(pred_normal, p=2, dim=0)
94
+
95
+ return pred_depth, pred_normal.permute(1, 2, 0)
96
+
97
+ @torch.inference_mode()
98
+ def infer(self, image: torch.Tensor, intrinsics: torch.Tensor = None):
99
+ # image: (B, 3, H, W) or (3, H, W), RGB in [0, 1]
100
+ if image.ndim == 3:
101
+ pred_depth, pred_normal = self.inference_one_image(image, intrinsics)
102
+ else:
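+ # Batched input: run per-image inference and stack the results.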
103
+ pred_depth, pred_normal = [], []
+ for i in range(image.shape[0]):
104
+ pred_depth_i, pred_normal_i = self.inference_one_image(image[i], intrinsics[i] if intrinsics is not None else None)
105
+ pred_depth.append(pred_depth_i)
106
+ pred_normal.append(pred_normal_i)
107
+ pred_depth = torch.stack(pred_depth, dim=0)
108
+ pred_normal = torch.stack(pred_normal, dim=0)
109
+
110
+ if intrinsics is not None:
111
+ return {
112
+ "depth_metric": pred_depth,
113
+ }
114
+ else:
115
+ return {
116
+ "depth_scale_invariant": pred_depth,
117
+ }
baselines/moge.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import sys
3
+ from typing import *
4
+ import importlib
5
+
6
+ import click
7
+ import torch
8
+ import utils3d
9
+
10
+ from moge.test.baseline import MGEBaselineInterface
11
+
12
+
13
+ class Baseline(MGEBaselineInterface):
14
+
15
+ def __init__(self, num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'):
16
+ super().__init__()
17
+ from moge.model import import_model_class_by_version
18
+ MoGeModel = import_model_class_by_version(version)
19
+ self.version = version
20
+
21
+ self.model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()
22
+
23
+ self.device = torch.device(device)
24
+ self.num_tokens = num_tokens
25
+ self.resolution_level = resolution_level
26
+ self.use_fp16 = use_fp16
27
+
28
+ @click.command()
29
+ @click.option('--num_tokens', type=int, default=None)
30
+ @click.option('--resolution_level', type=int, default=9)
31
+ @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl')
32
+ @click.option('--fp16', 'use_fp16', is_flag=True)
33
+ @click.option('--device', type=str, default='cuda:0')
34
+ @click.option('--version', type=str, default='v1')
35
+ @staticmethod
36
+ def load(num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'):
37
+ return Baseline(num_tokens, resolution_level, pretrained_model_name_or_path, use_fp16, device, version)
38
+
39
+ # Implementation for inference
40
+ @torch.inference_mode()
41
+ def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.FloatTensor] = None):
42
+ if intrinsics is not None:
43
+ fov_x, _ = utils3d.torch.intrinsics_to_fov(intrinsics)
44
+ fov_x = torch.rad2deg(fov_x)
45
+ else:
46
+ fov_x = None
47
+ output = self.model.infer(image, fov_x=fov_x, apply_mask=True, num_tokens=self.num_tokens)
48
+
49
+ if self.version == 'v1':
50
+ return {
51
+ 'points_scale_invariant': output['points'],
52
+ 'depth_scale_invariant': output['depth'],
53
+ 'intrinsics': output['intrinsics'],
54
+ }
55
+ else:
56
+ return {
57
+ 'points_metric': output['points'],
58
+ 'depth_metric': output['depth'],
59
+ 'intrinsics': output['intrinsics'],
60
+ }
61
+
62
+ @torch.inference_mode()
63
+ def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: torch.FloatTensor = None):
64
+ if intrinsics is not None:
65
+ fov_x, _ = utils3d.torch.intrinsics_to_fov(intrinsics)
66
+ fov_x = torch.rad2deg(fov_x)
67
+ else:
68
+ fov_x = None
69
+ output = self.model.infer(image, fov_x=fov_x, apply_mask=False, num_tokens=self.num_tokens, use_fp16=self.use_fp16)
70
+
71
+ if self.version == 'v1':
72
+ return {
73
+ 'points_scale_invariant': output['points'],
74
+ 'depth_scale_invariant': output['depth'],
75
+ 'intrinsics': output['intrinsics'],
76
+ }
77
+ else:
78
+ return {
79
+ 'points_metric': output['points'],
80
+ 'depth_metric': output['depth'],
81
+ 'intrinsics': output['intrinsics'],
82
+ }
83
+
configs/eval/all_benchmarks.json ADDED
@@ -0,0 +1,78 @@
1
+ {
2
+ "NYUv2": {
3
+ "path": "data/eval/NYUv2",
4
+ "width": 640,
5
+ "height": 480,
6
+ "split": ".index.txt",
7
+ "depth_unit": 1.0
8
+ },
9
+ "KITTI": {
10
+ "path": "data/eval/KITTI",
11
+ "width": 750,
12
+ "height": 375,
13
+ "split": ".index.txt",
14
+ "depth_unit": 1
15
+ },
16
+ "ETH3D": {
17
+ "path": "data/eval/ETH3D",
18
+ "width": 2048,
19
+ "height": 1365,
20
+ "split": ".index.txt",
21
+ "include_segmentation": true,
22
+ "depth_unit": 1
23
+ },
24
+ "iBims-1": {
25
+ "path": "data/eval/iBims-1",
26
+ "width": 640,
27
+ "height": 480,
28
+ "split": ".index.txt",
29
+ "has_sharp_boundary": true,
30
+ "include_segmentation": true,
31
+ "depth_unit": 1.0
32
+ },
33
+ "GSO": {
34
+ "path": "data/eval/GSO",
35
+ "width": 512,
36
+ "height": 512,
37
+ "split": ".index.txt"
38
+ },
39
+ "Sintel": {
40
+ "path": "data/eval/Sintel",
41
+ "width": 872,
42
+ "height": 436,
43
+ "split": ".index.txt",
44
+ "has_sharp_boundary": true,
45
+ "include_segmentation": true
46
+ },
47
+ "DDAD": {
48
+ "path": "data/eval/DDAD",
49
+ "width": 1400,
50
+ "height": 700,
51
+ "include_segmentation": true,
52
+ "split": ".index.txt",
53
+ "depth_unit": 1.0
54
+ },
55
+ "DIODE": {
56
+ "path": "data/eval/DIODE",
57
+ "width": 1024,
58
+ "height": 768,
59
+ "split": ".index.txt",
60
+ "include_segmentation": true,
61
+ "depth_unit": 1.0
62
+ },
63
+ "Spring": {
64
+ "path": "data/eval/Spring",
65
+ "width": 1920,
66
+ "height": 1080,
67
+ "split": ".index.txt",
68
+ "has_sharp_boundary": true
69
+ },
70
+ "HAMMER": {
71
+ "path": "data/eval/HAMMER",
72
+ "width": 1664,
73
+ "height": 832,
74
+ "split": ".index.txt",
75
+ "depth_unit": 1,
76
+ "has_sharp_boundary": true
77
+ }
78
+ }
configs/eval/benchmarks/ddad.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "DDAD": {
3
+ "path": "data/eval/DDAD",
4
+ "width": 1400,
5
+ "height": 700,
6
+ "include_segmentation": true,
7
+ "split": ".index.txt"
8
+ }
9
+ }
configs/eval/benchmarks/diode.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "DIODE": {
3
+ "path": "data/eval/DIODE",
4
+ "width": 1024,
5
+ "height": 768,
6
+ "split": ".index.txt",
7
+ "include_segmentation": true
8
+ }
9
+ }
configs/eval/benchmarks/eth3d.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "ETH3D": {
3
+ "path": "data/eval/ETH3D",
4
+ "width": 2048,
5
+ "height": 1365,
6
+ "split": ".index.txt",
7
+ "include_segmentation": true,
8
+ "depth_unit": 1
9
+ }
10
+ }
configs/eval/benchmarks/gso.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "GSO": {
3
+ "path": "data/eval/GSO",
4
+ "width": 512,
5
+ "height": 512,
6
+ "split": ".index.txt"
7
+ }
8
+ }
configs/eval/benchmarks/hammer.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "HAMMER": {
3
+ "path": "data/eval/HAMMER",
4
+ "width": 1664,
5
+ "height": 832,
6
+ "split": ".index.txt",
7
+ "depth_unit": 1,
8
+ "has_sharp_boundary": true
9
+ }
10
+ }
configs/eval/benchmarks/ibims-1.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "iBims-1": {
3
+ "path": "data/eval/iBims-1",
4
+ "width": 640,
5
+ "height": 480,
6
+ "split": ".index.txt",
7
+ "include_segmentation": true,
8
+ "has_sharp_boundary": true
9
+ }
10
+ }
configs/eval/benchmarks/kitti.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "KITTI": {
3
+ "path": "data/eval/KITTI",
4
+ "width": 750,
5
+ "height": 375,
6
+ "split": ".index.txt",
7
+ "depth_unit": 1
8
+ }
9
+ }
configs/eval/benchmarks/nyu.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "NYUv2": {
3
+ "path": "data/eval/NYUv2",
4
+ "width": 640,
5
+ "height": 480,
6
+ "split": ".test.txt"
7
+ }
8
+ }
configs/eval/benchmarks/sintel.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "Sintel": {
3
+ "path": "data/eval/Sintel",
4
+ "width": 872,
5
+ "height": 436,
6
+ "split": ".index.txt",
7
+ "include_segmentation": true,
8
+ "has_sharp_boundary": true
9
+ }
10
+ }
configs/eval/benchmarks/spring.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "Spring": {
3
+ "path": "data/eval/Spring",
4
+ "width": 1920,
5
+ "height": 1080,
6
+ "split": ".test.txt",
7
+ "has_sharp_boundary": true
8
+ }
9
+ }
configs/train/v1.json ADDED
@@ -0,0 +1,77 @@
1
+ {
2
+ "data": {
3
+ "aspect_ratio_range": [0.5, 2.0],
4
+ "area_range": [250000, 1000000],
5
+ "clamp_max_depth": 1000.0,
6
+ "center_augmentation": 0.5,
7
+ "fov_range_absolute": [1, 179],
8
+ "fov_range_relative": [0.01, 1.0],
9
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring"],
10
+ "datasets": [
11
+ {
12
+ "name": "TartanAir",
13
+ "path": "blobmnt/data_v3/TartanAir",
14
+ "label_type": "synthetic",
15
+ "index": ".index.txt",
16
+ "depth": "depth.png",
17
+ "weight": 4.8,
18
+ "center_augmentation": 0.25,
19
+ "fov_range_absolute": [30, 150],
20
+ "fov_range_relative": [0.5, 1.0],
21
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"]
22
+ }
23
+ ]
24
+ },
25
+ "model_version": "v1",
26
+ "model": {
27
+ "encoder": "dinov2_vitl14",
28
+ "remap_output": "exp",
29
+ "intermediate_layers": 4,
30
+ "dim_upsample": [256, 128, 64],
31
+ "dim_times_res_block_hidden": 2,
32
+ "num_res_blocks": 2,
33
+ "num_tokens_range": [1200, 2500],
34
+ "last_conv_channels": 32,
35
+ "last_conv_size": 1
36
+ },
37
+ "optimizer": {
38
+ "type": "AdamW",
39
+ "params": [
40
+ {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4},
41
+ {"params": {"include": ["*backbone.*"]}, "lr": 1e-5}
42
+ ]
43
+ },
44
+ "lr_scheduler": {
45
+ "type": "SequentialLR",
46
+ "params": {
47
+ "schedulers": [
48
+ {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}},
49
+ {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}}
50
+ ],
51
+ "milestones": [2000]
52
+ }
53
+ },
54
+ "low_resolution_training_steps": 50000,
55
+ "loss": {
56
+ "invalid": {},
57
+ "synthetic": {
58
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
59
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
60
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
61
+ "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}},
62
+ "normal": {"function": "normal_loss", "weight": 1.0},
63
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
64
+ },
65
+ "sfm": {
66
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
67
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
68
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
69
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
70
+ },
71
+ "lidar": {
72
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
73
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
74
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
75
+ }
76
+ }
77
+ }
docs/eval.md ADDED
@@ -0,0 +1,77 @@
1
+ # Evaluation
2
+
3
+ We provide a unified evaluation script that runs baselines on multiple benchmarks. It takes a baseline model and the evaluation configurations, evaluates on the fly, and writes the results to a JSON file.
4
+
5
+ ## Benchmarks
6
+
7
+ Download the processed datasets from [Hugging Face Datasets](https://huggingface.co/datasets/Ruicheng/monocular-geometry-evaluation) and put them in the `data/eval` directory using `huggingface-cli`:
8
+
9
+ ```bash
10
+ mkdir -p data/eval
11
+ huggingface-cli download Ruicheng/monocular-geometry-evaluation --repo-type dataset --local-dir data/eval --local-dir-use-symlinks False
12
+ ```
13
+
14
+ Then unzip the downloaded files:
15
+
16
+ ```bash
17
+ cd data/eval
18
+ unzip '*.zip'
19
+ # rm *.zip # optional: remove the zip files if you no longer need them
20
+ ```
21
+
22
+ ## Configuration
23
+
24
+ See [`configs/eval/all_benchmarks.json`](../configs/eval/all_benchmarks.json) for an example of evaluation configurations on all benchmarks. You can modify this file to evaluate on different benchmarks or different baselines.
25
+
26
+ ## Baseline
27
+
28
+ Example baselines are provided in [`baselines/`](../baselines/). Pass the path to the baseline's Python file to the `--baseline` argument of the evaluation script.
29
+
30
+ ## Run Evaluation
31
+
32
+ Run the script [`moge/scripts/eval_baseline.py`](../moge/scripts/eval_baseline.py).
33
+ For example,
34
+
35
+ ```bash
36
+ # Evaluate MoGe on the 10 benchmarks
37
+ python moge/scripts/eval_baseline.py --baseline baselines/moge.py --config configs/eval/all_benchmarks.json --output eval_output/moge.json --pretrained Ruicheng/moge-vitl --resolution_level 9
38
+
39
+ # Evaluate Depth Anything V2 on the 10 benchmarks. (NOTE: affine disparity)
40
+ python moge/scripts/eval_baseline.py --baseline baselines/da_v2.py --config configs/eval/all_benchmarks.json --output eval_output/da_v2.json
41
+ ```
42
+
43
+ The `--baseline`, `--config` and `--output` arguments are consumed by the evaluation script itself. The remaining arguments, e.g. `--pretrained` and `--resolution_level`, are custom options forwarded to the baseline's `load()` function.
44
+
45
+ Details of the arguments:
46
+
47
+ ```
48
+ Usage: eval_baseline.py [OPTIONS]
49
+
50
+ Evaluation script.
51
+
52
+ Options:
53
+ --baseline PATH Path to the baseline model python code.
54
+ --config PATH Path to the evaluation configurations. Defaults to
55
+ "configs/eval/all_benchmarks.json".
56
+ --output PATH Path to the output json file.
57
+ --oracle Use oracle mode for evaluation, i.e., use the GT intrinsics
58
+ input.
59
+ --dump_pred Dump prediction results.
60
+ --dump_gt Dump ground truth.
61
+ --help Show this message and exit.
62
+ ```
63
+
64
+
65
+
66
+ ## Wrap a Customized Baseline
67
+
68
+ Wrap any baseline method with [`moge.test.baseline.MGEBaselineInterface`](../moge/test/baseline.py).
69
+ See [`baselines/`](../baselines/) for more examples.
70
+
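+ For reference, here is a minimal sketch of such a wrapper, following the pattern of the baselines in this repository (the file name, click options, and the dummy prediction are illustrative; replace them with your own model and return the output keys your method supports):
+ 
+ ```python
+ # my_baseline.py -- hypothetical example, not part of the repository
+ from typing import *
+ 
+ import click
+ import torch
+ 
+ from moge.test.baseline import MGEBaselineInterface
+ 
+ 
+ class Baseline(MGEBaselineInterface):
+     def __init__(self, device: str = 'cuda'):
+         self.device = torch.device(device)
+         # Load your model here and move it to self.device.
+ 
+     @click.command()
+     @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
+     @staticmethod
+     def load(device: str = 'cuda'):
+         # Extra command-line arguments of the eval/infer scripts are forwarded to this click command.
+         return Baseline(device)
+ 
+     @torch.inference_mode()
+     def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+         # image: (3, H, W) or (B, 3, H, W), RGB in [0, 1]; intrinsics: optional normalized 3x3 camera matrix.
+         height, width = image.shape[-2:]
+         depth = torch.ones((*image.shape[:-3], height, width), device=image.device)  # dummy prediction
+         # Return the outputs your method provides, e.g. 'depth_scale_invariant', 'depth_metric',
+         # 'points_scale_invariant', 'points_metric', 'intrinsics'.
+         return {'depth_scale_invariant': depth}
+ ```
+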
71
+ It is a good idea to check the correctness of the baseline implementation by running inference on a small set of images via [`moge/scripts/infer_baselines.py`](../moge/scripts/infer_baselines.py):
72
+
73
+ ```bash
74
+ python moge/scripts/infer_baselines.py --baseline baselines/moge.py --input example_images/ --output infer_output/moge --pretrained Ruicheng/moge-vitl --maps --ply
75
+ ```
76
+
77
+
docs/train.md ADDED
@@ -0,0 +1,181 @@
1
+
2
+ # Training
3
+
4
+ This document provides instructions for training and finetuning the MoGe model.
5
+
6
+ ## Additional Requirements
7
+
8
+ The following packages other than those listed in [`pyproject.toml`](../pyproject.toml) are required for training and finetuning the MoGe model:
9
+
10
+ ```
11
+ accelerate
12
+ sympy
13
+ mlflow
14
+ ```
15
+
16
+ ## Data preparation
17
+
18
+ ### Dataset format
19
+
20
+ Each dataset should be organized as follows:
21
+
22
+ ```
23
+ somedataset
24
+ ├── .index.txt # A list of instance paths
25
+ ├── folder1
26
+ │ ├── instance1 # Each instance is in a folder
27
+ │ │ ├── image.jpg # RGB image.
28
+ │ │ ├── depth.png # 16-bit depth. See moge/utils/io.py for details
29
+ │ │ ├── meta.json # Stores "intrinsics" as a 3x3 matrix
30
+ │ │ └── ... # Other components such as segmentation mask, normal map, etc.
31
+ ...
32
+ ```
33
+
34
+ * `.index.txt` is placed at the top directory and stores a list of instance paths in this dataset. The dataloader will look for instances in this list. You may also use a custom split, e.g. `.train.txt` or `.val.txt`, and specify it in the configuration file.
35
+
36
+ * For depth images, it is recommended to use `read_depth()` and `write_depth()` in [`moge/utils/io.py`](../moge/utils/io.py). Depth is stored on a logarithmic scale in 16-bit PNG format, which offers a good balance of precision, dynamic range, and compression ratio compared to 16-bit/32-bit EXR and linear-depth formats. It also encodes `NaN` and `Inf` for invalid depth values.
37
+
38
+ * The `meta.json` should be a dictionary containing the key `intrinsics`, which stores the **normalized** 3x3 camera intrinsics matrix. You may put additional metadata in the same file (see the sketch after this list).
39
+
40
+ * We also support reading and storing segmentation masks for evaluation data (see the evaluation of local points in the paper), which are saved in PNG format with semantic labels stored in the PNG metadata as JSON strings. See `read_segmentation()` and `write_segmentation()` in [`moge/utils/io.py`](../moge/utils/io.py) for details.
41
+
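+ For illustration, here is a rough sketch of how a single instance could be written (the exact I/O signatures live in [`moge/utils/io.py`](../moge/utils/io.py); the `write_depth(path, depth)` call and the `write_instance` helper below are assumptions for illustration only):
+ 
+ ```python
+ # Hypothetical helper for illustration only -- check moge/utils/io.py for the actual I/O signatures.
+ import json
+ from pathlib import Path
+ 
+ import cv2
+ import numpy as np
+ 
+ from moge.utils.io import write_depth  # assumed signature: write_depth(path, depth)
+ 
+ 
+ def write_instance(folder: Path, image: np.ndarray, depth: np.ndarray, intrinsics: np.ndarray):
+     # image: (H, W, 3) uint8 RGB; depth: (H, W) float32 with NaN/Inf for invalid pixels;
+     # intrinsics: normalized 3x3 camera matrix, as described above.
+     folder.mkdir(parents=True, exist_ok=True)
+     cv2.imwrite(str(folder / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+     write_depth(folder / 'depth.png', depth)
+     with open(folder / 'meta.json', 'w') as f:
+         json.dump({'intrinsics': intrinsics.tolist()}, f)
+     # Remember to append the instance's relative path to the dataset's .index.txt (or your custom split file).
+ ```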
42
+
43
+ ### Visual inspection
44
+
45
+ We provide a script to visualize the data and check the data quality. It will export the instance as a PLY file for visualization of point cloud.
46
+
47
+ ```bash
48
+ python moge/scripts/vis_data.py PATH_TO_INSTANCE --ply [-o SOMEWHERE_ELSE_TO_SAVE_VIS]
49
+ ```
50
+
51
+ ### DataLoader
52
+
53
+ Our training dataloader is customized to handle data loading, perspective cropping, and augmentation in a multithreaded pipeline. Please refer to [`moge/train/dataloader.py`](../moge/train/dataloader.py) for details.
54
+
55
+
56
+ ## Configuration
57
+
58
+ See [`configs/train/v1.json`](../configs/train/v1.json) for an example configuration file. The configuration file defines the hyperparameters for training the MoGe model.
59
+ Here is a commented configuration for reference:
60
+
61
+ ```json
62
+ {
63
+ "data": {
64
+ "aspect_ratio_range": [0.5, 2.0], # Range of aspect ratio of sampled images
65
+ "area_range": [250000, 1000000], # Range of sampled image area in pixels
66
+ "clamp_max_depth": 1000.0, # Maximum far/near depth ratio used for clamping
67
+ "center_augmentation": 0.5, # Ratio of center crop augmentation
68
+ "fov_range_absolute": [1, 179], # Absolute range of FOV in degrees
69
+ "fov_range_relative": [0.01, 1.0], # Relative range of FOV to the original FOV
70
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring"], # List of image augmentation techniques
71
+ "datasets": [
72
+ {
73
+ "name": "TartanAir", # Name of the dataset. Name it as you like.
74
+ "path": "data/TartanAir", # Path to the dataset
75
+ "label_type": "synthetic", # Label type for this dataset. Losses will be applied accordingly; see the "loss" config.
76
+ "weight": 4.8, # Probability of sampling this dataset
77
+ "index": ".index.txt", # File name of the index file. Defaults to .index.txt
78
+ "depth": "depth.png", # File name of depth images. Defaults to depth.png
79
+ "center_augmentation": 0.25, # Dataset-specific hyperparameters below override the global ones above.
80
+ "fov_range_absolute": [30, 150],
81
+ "fov_range_relative": [0.5, 1.0],
82
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"]
83
+ }
84
+ ]
85
+ },
86
+ "model_version": "v1", # Model version. If you have multiple model variants, you can use this to switch between them.
87
+ "model": { # Model hyperparameters. Will be passed to Model __init__() as kwargs.
88
+ "encoder": "dinov2_vitl14",
89
+ "remap_output": "exp",
90
+ "intermediate_layers": 4,
91
+ "dim_upsample": [256, 128, 64],
92
+ "dim_times_res_block_hidden": 2,
93
+ "num_res_blocks": 2,
94
+ "num_tokens_range": [1200, 2500],
95
+ "last_conv_channels": 32,
96
+ "last_conv_size": 1
97
+ },
98
+ "optimizer": { # Reflection-like optimizer configuration. See build_optimizer() in moge/train/utils.py for details; a rough illustrative sketch follows this configuration block.
99
+ "type": "AdamW",
100
+ "params": [
101
+ {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4},
102
+ {"params": {"include": ["*backbone.*"]}, "lr": 1e-5}
103
+ ]
104
+ },
105
+ "lr_scheduler": { # Reflection-like lr_scheduler configuration. See build_lr_scheduler() in moge/train/utils.py for details.
106
+ "type": "SequentialLR",
107
+ "params": {
108
+ "schedulers": [
109
+ {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}},
110
+ {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}}
111
+ ],
112
+ "milestones": [2000]
113
+ }
114
+ },
115
+ "low_resolution_training_steps": 50000, # Total number of low-resolution training steps. This makes early-stage training faster; later-stage training on varying-size images is slower.
116
+ "loss": {
117
+ "invalid": {}, # Invalid instances (e.g. a runtime error when loading data)
118
+ "synthetic": { # Below are loss hyperparameters
119
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
120
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
121
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
122
+ "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}},
123
+ "normal": {"function": "normal_loss", "weight": 1.0},
124
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
125
+ },
126
+ "sfm": {
127
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
128
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
129
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
130
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
131
+ },
132
+ "lidar": {
133
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
134
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
135
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
136
+ }
137
+ }
138
+ }
139
+ ```
140
+
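+ The reflection-like `optimizer` and `lr_scheduler` entries are resolved by `build_optimizer()` and `build_lr_scheduler()` in `moge/train/utils.py`. As a rough illustration of the idea only (not the repository's implementation), the include/exclude patterns can be mapped to optimizer parameter groups like this:
+ 
+ ```python
+ # Illustrative sketch -- the actual logic lives in moge/train/utils.py and may differ.
+ from fnmatch import fnmatch
+ 
+ import torch
+ 
+ 
+ def build_optimizer_sketch(model: torch.nn.Module, config: dict) -> torch.optim.Optimizer:
+     groups = []
+     for group_cfg in config['params']:
+         include = group_cfg['params'].get('include', ['*'])
+         exclude = group_cfg['params'].get('exclude', [])
+         params = [
+             p for name, p in model.named_parameters()
+             if any(fnmatch(name, pat) for pat in include)
+             and not any(fnmatch(name, pat) for pat in exclude)
+         ]
+         groups.append({'params': params, 'lr': group_cfg['lr']})
+     # The optimizer class is looked up by name, e.g. "AdamW" -> torch.optim.AdamW.
+     return getattr(torch.optim, config['type'])(groups)
+ ```
+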
141
+ ## Run Training
142
+
143
+ Launch the training script [`moge/scripts/train.py`](../moge/scripts/train.py). Note that we use [`accelerate`](https://github.com/huggingface/accelerate) for distributed training.
144
+
145
+ ```bash
146
+ accelerate launch \
147
+ --num_processes 8 \
148
+ moge/scripts/train.py \
149
+ --config configs/train/v1.json \
150
+ --workspace workspace/debug \
151
+ --gradient_accumulation_steps 2 \
152
+ --batch_size_forward 2 \
153
+ --checkpoint latest \
154
+ --enable_gradient_checkpointing True \
155
+ --vis_every 1000 \
156
+ --enable_mlflow True
157
+ ```
158
+
159
+
160
+ ## Finetuning
161
+
162
+ To finetune the pre-trained MoGe model, download the model checkpoint and put it in a local directory, e.g. `pretrained/moge-vitl.pt`.
163
+
164
+ > NOTE: when finetuning the pretrained MoGe model, a much lower learning rate is required.
165
+ The suggested learning rate for finetuning is at most 1e-5 for the head and 1e-6 for the backbone.
166
+ A total batch size of at least 32 is recommended.
167
+ The settings in the default configuration are not necessarily optimal for specific datasets and may require further tuning.
168
+
169
+ ```bash
170
+ accelerate launch \
171
+ --num_processes 8 \
172
+ moge/scripts/train.py \
173
+ --config configs/train/v1.json \
174
+ --workspace workspace/debug \
175
+ --gradient_accumulation_steps 2 \
176
+ --batch_size_forward 2 \
177
+ --checkpoint pretrained/moge-vitl.pt \
178
+ --enable_gradient_checkpointing True \
179
+ --vis_every 1000 \
180
+ --enable_mlflow True
181
+ ```
example_images/01_HouseIndoor.jpg ADDED

Git LFS Details

  • SHA256: 3eb519bc68d4262af0c68166ca69e786cac5f6656a1083f4c585c4a94005c859
  • Pointer size: 131 Bytes
  • Size of remote file: 322 kB
example_images/02_Office.jpg ADDED

Git LFS Details

  • SHA256: 28767640002f93b703b24a34a6d75ca24b1ef093a19f52ef0f9d3b074ef68c61
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
example_images/03_Traffic.jpg ADDED

Git LFS Details

  • SHA256: 4fa8b46849dd3de5b3b0a141d6aafe98e190f578ccec0c9dacc440cd8434db11
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
example_images/04_BunnyCake.jpg ADDED

Git LFS Details

  • SHA256: 7ddd187d91ebc2cf626bc51a26e1fc71d478237ce348732ae547f83655f05260
  • Pointer size: 130 Bytes
  • Size of remote file: 69.1 kB
example_images/05_Mountain.jpg ADDED

Git LFS Details

  • SHA256: 670d322f6588713f7d9c7349091de0aacb2a5b0b37c7b7433995e110fb2bcfbc
  • Pointer size: 131 Bytes
  • Size of remote file: 666 kB
example_images/06_MaitreyaBuddha.png ADDED

Git LFS Details

  • SHA256: 396c5fd722bf5a21b931cbb70b883d6b1d5f9bab439cc426ec2f606fc2b7872d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
example_images/07_Breads.jpg ADDED

Git LFS Details

  • SHA256: a95c2cab81412e252ee5a56a6100df31bb83de0f117607ca8476478f7f152a7b
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
example_images/08_CatGirl.png ADDED

Git LFS Details

  • SHA256: 57fa6d587d598e7a428e8997b86d5c3a06e0e18529bfad8bab78ae03a1f5820f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
example_images/09_Restaurant.jpg ADDED

Git LFS Details

  • SHA256: b2bb7b5a1e91a174101109b0976b8ae2a4d6bb7d6eadad6569106ed102d0d5a6
  • Pointer size: 131 Bytes
  • Size of remote file: 794 kB
example_images/10_MedievalVillage.jpg ADDED

Git LFS Details

  • SHA256: 718ed1aeb1e0010194c5cf0e95371e6a29d45b84e93efbed63ff4cc60e74508b
  • Pointer size: 131 Bytes
  • Size of remote file: 465 kB
example_images/11_Room.jpg ADDED

Git LFS Details

  • SHA256: 8f34b99e89f3a57952bb88f11a6dc87e4a75423f55ad26748783c92854543cf5
  • Pointer size: 131 Bytes
  • Size of remote file: 582 kB
example_images/12_StylizedHouses.jpg ADDED

Git LFS Details

  • SHA256: 18120b27ea499ef9c921a5a02e987c687327896c7bb649a9703682737d25a6b8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
example_images/panorama/Braunschweig_Panoram.jpg ADDED

Git LFS Details

  • SHA256: abc31b78f03a0b5254f3735bc3201c28d21b6855708f971ce4b6a740dfbddcba
  • Pointer size: 131 Bytes
  • Size of remote file: 563 kB
moge/__init__.py ADDED
File without changes
moge/model/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ import importlib
2
+ from typing import *
3
+
4
+ if TYPE_CHECKING:
5
+ from .v1 import MoGeModel as MoGeModelV1
6
+ from .v2 import MoGeModel as MoGeModelV2
7
+
8
+
9
+ def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
10
+ assert version in ['v1', 'v2'], f'Unsupported model version: {version}'
11
+
12
+ try:
13
+ module = importlib.import_module(f'.{version}', __package__)
14
+ except ModuleNotFoundError:
15
+ raise ValueError(f'Model version "{version}" not found.')
16
+
17
+ cls = getattr(module, 'MoGeModel')
18
+ return cls
moge/model/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
moge/model/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
+ class Weights(Enum):
15
+ LVD142M = "LVD142M"
16
+
17
+
18
+ def _make_dinov2_model(
19
+ *,
20
+ arch_name: str = "vit_large",
21
+ img_size: int = 518,
22
+ patch_size: int = 14,
23
+ init_values: float = 1.0,
24
+ ffn_layer: str = "mlp",
25
+ block_chunks: int = 0,
26
+ num_register_tokens: int = 0,
27
+ interpolate_antialias: bool = False,
28
+ interpolate_offset: float = 0.1,
29
+ pretrained: bool = True,
30
+ weights: Union[Weights, str] = Weights.LVD142M,
31
+ **kwargs,
32
+ ):
33
+ from ..models import vision_transformer as vits
34
+
35
+ if isinstance(weights, str):
36
+ try:
37
+ weights = Weights[weights]
38
+ except KeyError:
39
+ raise AssertionError(f"Unsupported weights: {weights}")
40
+
41
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
42
+ vit_kwargs = dict(
43
+ img_size=img_size,
44
+ patch_size=patch_size,
45
+ init_values=init_values,
46
+ ffn_layer=ffn_layer,
47
+ block_chunks=block_chunks,
48
+ num_register_tokens=num_register_tokens,
49
+ interpolate_antialias=interpolate_antialias,
50
+ interpolate_offset=interpolate_offset,
51
+ )
52
+ vit_kwargs.update(**kwargs)
53
+ model = vits.__dict__[arch_name](**vit_kwargs)
54
+
55
+ if pretrained:
56
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
57
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
58
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
59
+ model.load_state_dict(state_dict, strict=True)
60
+
61
+ return model
62
+
63
+
64
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
65
+ """
66
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
67
+ """
68
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
69
+
70
+
71
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
72
+ """
73
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
74
+ """
75
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
76
+
77
+
78
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
79
+ """
80
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
81
+ """
82
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
83
+
84
+
85
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
86
+ """
87
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
88
+ """
89
+ return _make_dinov2_model(
90
+ arch_name="vit_giant2",
91
+ ffn_layer="swiglufused",
92
+ weights=weights,
93
+ pretrained=pretrained,
94
+ **kwargs,
95
+ )
96
+
97
+
98
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
99
+ """
100
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
101
+ """
102
+ return _make_dinov2_model(
103
+ arch_name="vit_small",
104
+ pretrained=pretrained,
105
+ weights=weights,
106
+ num_register_tokens=4,
107
+ interpolate_antialias=True,
108
+ interpolate_offset=0.0,
109
+ **kwargs,
110
+ )
111
+
112
+
113
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
114
+ """
115
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
116
+ """
117
+ return _make_dinov2_model(
118
+ arch_name="vit_base",
119
+ pretrained=pretrained,
120
+ weights=weights,
121
+ num_register_tokens=4,
122
+ interpolate_antialias=True,
123
+ interpolate_offset=0.0,
124
+ **kwargs,
125
+ )
126
+
127
+
128
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
129
+ """
130
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
131
+ """
132
+ return _make_dinov2_model(
133
+ arch_name="vit_large",
134
+ pretrained=pretrained,
135
+ weights=weights,
136
+ num_register_tokens=4,
137
+ interpolate_antialias=True,
138
+ interpolate_offset=0.0,
139
+ **kwargs,
140
+ )
141
+
142
+
143
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
144
+ """
145
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
146
+ """
147
+ return _make_dinov2_model(
148
+ arch_name="vit_giant2",
149
+ ffn_layer="swiglufused",
150
+ weights=weights,
151
+ pretrained=pretrained,
152
+ num_register_tokens=4,
153
+ interpolate_antialias=True,
154
+ interpolate_offset=0.0,
155
+ **kwargs,
156
+ )
moge/model/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
+ class CenterPadding(nn.Module):
24
+ def __init__(self, multiple):
25
+ super().__init__()
26
+ self.multiple = multiple
27
+
28
+ def _get_pad(self, size):
29
+ new_size = math.ceil(size / self.multiple) * self.multiple
30
+ pad_size = new_size - size
31
+ pad_size_left = pad_size // 2
32
+ pad_size_right = pad_size - pad_size_left
33
+ return pad_size_left, pad_size_right
34
+
35
+ @torch.inference_mode()
36
+ def forward(self, x):
37
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38
+ output = F.pad(x, pads)
39
+ return output
moge/model/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
moge/model/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ # warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ # warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ # warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x