File size: 15,712 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
"""sync."""

import atexit
import datetime
import fnmatch
import os
import queue
import sys
import tempfile
import threading
import time
from typing import List, Optional
from urllib.parse import quote as url_quote

import wandb
from wandb.proto import wandb_internal_pb2  # type: ignore
from wandb.sdk.interface.interface_queue import InterfaceQueue
from wandb.sdk.internal import context, datastore, handler, sender, tb_watcher
from wandb.sdk.internal.settings_static import SettingsStatic
from wandb.sdk.lib import filesystem
from wandb.util import check_and_warn_old

# File-name suffix used to pick a run's data file out of a directory listing.
WANDB_SUFFIX = ".wandb"
# Appended to a data file's full name to form its "already synced" marker file.
SYNCED_SUFFIX = ".synced"
# Substring that identifies a TensorBoard event file by name.
TFEVENT_SUBSTRING = ".tfevents."


class _LocalRun:
    """A run directory on disk, described by what its name encodes.

    The directory's base name is expected to look like
    ``run-<YYYYmmdd_HHMMSS>-<id>`` or ``offline-run-<YYYYmmdd_HHMMSS>-<id>``.
    """

    def __init__(self, path, synced=None):
        """Record ``path`` and parse the offline flag and timestamp from it.

        Args:
            path: Path to the run directory. Trailing path separators are
                tolerated.
            synced: Sync state supplied by the caller; stored unchanged.

        Raises:
            IndexError: If the directory name contains no ``run-`` marker.
            ValueError: If the text following ``run-`` is not a
                ``%Y%m%d_%H%M%S`` timestamp.
        """
        self.path = path
        self.synced = synced
        # Normalise before taking the base name: for "wandb/run-.../" a bare
        # os.path.basename() returns "", which used to report offline=False
        # and then fail the timestamp lookup with IndexError.
        name = os.path.basename(os.path.normpath(path))
        self.offline = name.startswith("offline-")
        stamp = name.split("run-")[1].split("-")[0]
        self.datetime = datetime.datetime.strptime(stamp, "%Y%m%d_%H%M%S")

    def __str__(self):
        """Return the path exactly as it was given to the constructor."""
        return self.path


class SyncThread(threading.Thread):
    def __init__(
        self,
        sync_list,
        project=None,
        entity=None,
        run_id=None,
        job_type=None,
        view=None,
        verbose=None,
        mark_synced=None,
        app_url=None,
        sync_tensorboard=None,
        log_path=None,
        append=None,
        skip_console=None,
    ):
        """Store the sync options and create a scratch directory.

        Args:
            sync_list: Paths (files or directories) to process in ``run``.
            project: If set, replaces the project on run records.
            entity: If set, replaces the entity on run records.
            run_id: If set, replaces the run id on run records.
            job_type: If set, replaces the job type on run records.
            view: If truthy, records are printed instead of sent.
            verbose: In view mode, print whole records rather than their type.
            mark_synced: If truthy, write a ``.synced`` marker after a
                finished run has been sent.
            app_url: Prefix of the run URL that gets printed.
            sync_tensorboard: If truthy, look for tfevents files to sync.
            log_path: If not None, printed when the thread starts.
            append: If truthy, the sender is set up with ``resume="allow"``.
            skip_console: If truthy, console output records are not sent.
        """
        super().__init__()
        # mark this process as internal
        wandb._set_internal_process(disable=True)

        # What to sync and how chatty to be about it.
        self._sync_list = sync_list
        self._log_path = log_path
        self._view = view
        self._verbose = verbose

        # Overrides applied to the run record.
        self._run_id = run_id
        self._project = project
        self._entity = entity
        self._job_type = job_type

        # Behaviour switches.
        self._sync_tensorboard = sync_tensorboard
        self._skip_console = skip_console
        self._append = append
        self._mark_synced = mark_synced
        self._app_url = app_url

        # Scratch space used for tensorboard syncs; removed at interpreter exit.
        scratch = tempfile.TemporaryDirectory()
        self._tmp_dir = scratch
        atexit.register(scratch.cleanup)

    def _parse_pb(self, data, exit_pb=None):
        """Decode one serialized record and decide what to do with it.

        Args:
            data: Serialized ``wandb_internal_pb2.Record`` bytes.
            exit_pb: The exit record held back by an earlier call, if any.

        Returns:
            A ``(record, exit_pb, skip)`` tuple. ``skip`` is True when the
            caller should not send ``record``; ``exit_pb`` is the (possibly
            updated) held-back exit record to hand to the next call.
        """
        record = wandb_internal_pb2.Record()
        record.ParseFromString(data)
        kind = record.WhichOneof("record_type")

        # View mode only prints what is in the file; nothing is sent.
        if self._view:
            shown = record if self._verbose else kind
            print("Record:", shown)  # noqa: T201
            return record, exit_pb, True

        # The exit record is held back until the "final" record shows up.
        if kind == "exit":
            return record, record, True

        if self._skip_console and kind in ("output", "output_raw"):
            return record, exit_pb, True

        if kind == "run":
            # Apply only the overrides that were actually provided.
            overrides = (
                ("run_id", self._run_id),
                ("project", self._project),
                ("entity", self._entity),
                ("job_type", self._job_type),
            )
            for field, value in overrides:
                if value:
                    setattr(record.run, field, value)
            # Ask for a response so the caller can read back the run result.
            record.control.req_resp = True
        elif kind == "final":
            assert exit_pb, "final seen without exit"
            # Send the held-back exit record in place of "final".
            record, exit_pb = exit_pb, None

        return record, exit_pb, False

    def _find_tfevent_files(self, sync_item):
        tb_event_files = 0
        tb_logdirs = []
        tb_root = None
        if self._sync_tensorboard:
            if os.path.isdir(sync_item):
                files = []
                for dirpath, _, _files in os.walk(sync_item):
                    for f in _files:
                        if TFEVENT_SUBSTRING in f:
                            files.append(os.path.join(dirpath, f))
                for tfevent in files:
                    tb_event_files += 1
                    tb_dir = os.path.dirname(os.path.abspath(tfevent))
                    if tb_dir not in tb_logdirs:
                        tb_logdirs.append(tb_dir)
                if len(tb_logdirs) > 0:
                    tb_root = os.path.dirname(os.path.commonprefix(tb_logdirs))

            elif TFEVENT_SUBSTRING in sync_item:
                tb_root = os.path.dirname(os.path.abspath(sync_item))
                tb_logdirs.append(tb_root)
                tb_event_files = 1
        return tb_event_files, tb_logdirs, tb_root

    def _setup_tensorboard(self, tb_root, tb_logdirs, tb_event_files, sync_item):
        """Return true if this sync item can be synced as tensorboard."""
        if tb_root is not None:
            if tb_event_files > 0 and sync_item.endswith(WANDB_SUFFIX):
                wandb.termwarn("Found .wandb file, not streaming tensorboard metrics.")
            else:
                print(f"Found {tb_event_files} tfevent files in {tb_root}")  # noqa: T201
                if len(tb_logdirs) > 3:
                    wandb.termwarn(
                        f"Found {len(tb_logdirs)} directories containing tfevent files. "
                        "If these represent multiple experiments, sync them "
                        "individually or pass a list of paths."
                    )
                return True
        return False

    def _send_tensorboard(self, tb_root, tb_logdirs, send_manager):
        """Sync tfevents log directories as a run.

        Args:
            tb_root: Root directory passed to the watcher for every logdir.
            tb_logdirs: Directories containing tfevents files.
            send_manager: An existing SendManager. Only its ``_api`` and
                ``_settings`` are read; the name is rebound to a freshly
                built SendManager below.
        """
        # Without an explicit entity, fall back to the one the API reports
        # for the current viewer. The result is cached on the instance.
        if self._entity is None:
            viewer, _ = send_manager._api.viewer_server_info()
            self._entity = viewer.get("entity")
        # Build the run record, generating an id / project name when none
        # was supplied.
        proto_run = wandb_internal_pb2.RunRecord()
        proto_run.run_id = self._run_id or wandb.util.generate_id()
        proto_run.project = self._project or wandb.util.auto_project_name(None)
        proto_run.entity = self._entity
        proto_run.telemetry.feature.sync_tfevents = True

        url = (
            f"{self._app_url}"
            f"/{url_quote(proto_run.entity)}"
            f"/{url_quote(proto_run.project)}"
            f"/runs/{url_quote(proto_run.run_id)}"
        )
        print(f"Syncing: {url} ...")  # noqa: T201
        sys.stdout.flush()
        # using a handler here automatically handles the step
        # logic, adds summaries to the run, and handles different
        # file types (like images)... but we need to remake the send_manager
        record_q = queue.Queue()
        sender_record_q = queue.Queue()
        new_interface = InterfaceQueue(record_q)
        context_keeper = context.ContextKeeper()
        # From here on `send_manager` is the new manager; it reuses the
        # settings of the one that was passed in.
        send_manager = sender.SendManager(
            settings=send_manager._settings,
            record_q=sender_record_q,
            result_q=queue.Queue(),
            interface=new_interface,
            context_keeper=context_keeper,
        )
        record = send_manager._interface._make_record(run=proto_run)
        # Settings rooted in the scratch directory created in __init__.
        settings = wandb.Settings(
            root_dir=self._tmp_dir.name,
            run_id=proto_run.run_id,
            x_start_time=time.time(),
        )

        settings_static = SettingsStatic(settings.to_proto())

        # The handler reads from record_q (fed through new_interface) and is
        # given sender_record_q, the sender's input queue, as its writer_q.
        handle_manager = handler.HandleManager(
            settings=settings_static,
            record_q=record_q,
            result_q=None,
            stopped=False,
            writer_q=sender_record_q,
            interface=new_interface,
            context_keeper=context_keeper,
        )

        filesystem.mkdir_exists_ok(settings.files_dir)
        send_manager.send_run(record, file_dir=settings.files_dir)
        watcher = tb_watcher.TBWatcher(settings, proto_run, new_interface, True)

        # Register every log directory with the watcher, then finish it.
        for tb in tb_logdirs:
            watcher.add(tb, True, tb_root)
            sys.stdout.flush()
        watcher.finish()

        # Pump records: handle one, then send everything the sender has
        # queued, updating a one-line spinner on each pass.
        # NOTE(review): len() on the managers presumably reports pending
        # records -- confirm against HandleManager / SendManager.
        progress_step = 0
        spinner_states = ["-", "\\", "|", "/"]
        line = " Uploading data to wandb\r"
        while len(handle_manager) > 0:
            data = next(handle_manager)
            handle_manager.handle(data)
            while len(send_manager) > 0:
                data = next(send_manager)
                send_manager.send(data)

            print_line = spinner_states[progress_step % 4] + line
            wandb.termlog(print_line, newline=False, prefix=True)
            progress_step += 1

        # finish sending any data
        while len(send_manager) > 0:
            data = next(send_manager)
            send_manager.send(data)
        sys.stdout.flush()
        handle_manager.finish()
        send_manager.finish()

    def _robust_scan(self, ds):
        """Attempt to scan data, handling incomplete files."""
        try:
            return ds.scan_data()
        except AssertionError as e:
            if ds.in_last_block():
                wandb.termwarn(
                    f".wandb file is incomplete ({e}), be sure to sync this run again once it's finished"
                )
                return None
            else:
                raise

    def run(self):
        """Thread entry point: process every item of the sync list in order.

        Each item is either synced as tensorboard data or replayed record by
        record from its ``.wandb`` file through a SendManager.
        """
        if self._log_path is not None:
            print(f"Find logs at: {self._log_path}")  # noqa: T201
        for sync_item in self._sync_list:
            tb_event_files, tb_logdirs, tb_root = self._find_tfevent_files(sync_item)
            if os.path.isdir(sync_item):
                files = os.listdir(sync_item)
                filtered_files = list(filter(lambda f: f.endswith(WANDB_SUFFIX), files))
                # With no tfevents found, a directory is only usable when
                # check_and_warn_old() is falsy and it holds exactly one
                # .wandb file.
                if tb_root is None and (
                    check_and_warn_old(files) or len(filtered_files) != 1
                ):
                    print(f"Skipping directory: {sync_item}")  # noqa: T201
                    continue
                # From here on sync_item is the .wandb file (the first one
                # listed), if the directory has any.
                if len(filtered_files) > 0:
                    sync_item = os.path.join(sync_item, filtered_files[0])
            sync_tb = self._setup_tensorboard(
                tb_root, tb_logdirs, tb_event_files, sync_item
            )
            # If we're syncing tensorboard, let's use a tmp dir for images etc.
            root_dir = self._tmp_dir.name if sync_tb else os.path.dirname(sync_item)

            # When appending we are allowing a possible resume, ie the run
            # does not have to exist already
            resume = "allow" if self._append else None

            # NOTE(review): `sm` is not finished on the tensorboard path or
            # when the file turns out to be empty (both `continue` below) --
            # confirm SendManager.setup() holds nothing that needs releasing.
            sm = sender.SendManager.setup(root_dir, resume=resume)
            if sync_tb:
                self._send_tensorboard(tb_root, tb_logdirs, sm)
                continue

            ds = datastore.DataStore()
            try:
                ds.open_for_scan(sync_item)
            except AssertionError as e:
                print(f".wandb file is empty ({e}), skipping: {sync_item}")  # noqa: T201
                continue

            # save exit for final send
            exit_pb = None
            # Becomes True once an exit record has been seen.
            finished = False
            # Ensures the "Syncing: <url>" line is printed only once.
            shown = False
            while True:
                data = self._robust_scan(ds)
                # None means there is nothing more to read.
                if data is None:
                    break
                pb, exit_pb, cont = self._parse_pb(data, exit_pb)
                if exit_pb is not None:
                    finished = True
                # _parse_pb asked us not to send this record.
                if cont:
                    continue
                sm.send(pb)
                # send any records that were added in previous send
                while not sm._record_q.empty():
                    data = sm._record_q.get(block=True)
                    sm.send(data)

                # _parse_pb sets req_resp on run records, so wait for the
                # result and use it to print the run URL.
                if pb.control.req_resp:
                    result = sm._result_q.get(block=True)
                    result_type = result.WhichOneof("result_type")
                    if not shown and result_type == "run_result":
                        r = result.run_result.run
                        # TODO(jhr): hardcode until we have settings in sync
                        url = (
                            f"{self._app_url}"
                            f"/{url_quote(r.entity)}"
                            f"/{url_quote(r.project)}"
                            f"/runs/{url_quote(r.run_id)}"
                        )
                        print(f"Syncing: {url} ... ", end="")  # noqa: T201
                        sys.stdout.flush()
                        shown = True
            sm.finish()
            # Only mark synced if the run actually finished
            if self._mark_synced and not self._view and finished:
                synced_file = f"{sync_item}{SYNCED_SUFFIX}"
                # Create (or truncate) an empty marker file.
                with open(synced_file, "w"):
                    pass
            print("done.")  # noqa: T201


class SyncManager:
    """Collects paths to sync and hands them to a single ``SyncThread``."""

    def __init__(
        self,
        project=None,
        entity=None,
        run_id=None,
        job_type=None,
        mark_synced=None,
        app_url=None,
        view=None,
        verbose=None,
        sync_tensorboard=None,
        log_path=None,
        append=None,
        skip_console=None,
    ):
        """Remember the options that ``start`` forwards to the sync thread."""
        # Work queue and the worker that will consume it.
        self._sync_list = []
        self._thread = None

        # Run-record overrides.
        self._run_id = run_id
        self._project = project
        self._entity = entity
        self._job_type = job_type

        # Output / display options.
        self._view = view
        self._verbose = verbose
        self._log_path = log_path
        self._app_url = app_url

        # Behaviour switches.
        self._mark_synced = mark_synced
        self._sync_tensorboard = sync_tensorboard
        self._append = append
        self._skip_console = skip_console

    def status(self):
        """Do nothing; always returns None."""

    def add(self, p):
        """Queue ``p`` for syncing, stored as an absolute path string."""
        absolute = os.path.abspath(str(p))
        self._sync_list.append(absolute)

    def start(self):
        """Create the sync thread with the stored options and start it."""
        # create a thread for each file?
        options = {
            "sync_list": self._sync_list,
            "project": self._project,
            "entity": self._entity,
            "run_id": self._run_id,
            "job_type": self._job_type,
            "view": self._view,
            "verbose": self._verbose,
            "mark_synced": self._mark_synced,
            "app_url": self._app_url,
            "sync_tensorboard": self._sync_tensorboard,
            "log_path": self._log_path,
            "append": self._append,
            "skip_console": self._skip_console,
        }
        worker = SyncThread(**options)
        self._thread = worker
        worker.start()

    def is_done(self):
        """Return True once the sync thread is no longer alive."""
        still_running = self._thread.is_alive()
        return not still_running

    def poll(self):
        """Sleep for one second, then return False."""
        time.sleep(1)
        return False


def get_runs(
    include_offline: bool = True,
    include_online: bool = True,
    include_synced: bool = False,
    include_unsynced: bool = True,
    exclude_globs: Optional[List[str]] = None,
    include_globs: Optional[List[str]] = None,
) -> tuple:
    """List local runs found under ``.wandb`` or ``wandb`` in the cwd.

    Args:
        include_offline: Consider ``offline-run-*`` directories.
        include_online: Consider ``run-*`` directories.
        include_synced: Return runs regarded as synced (a ``.synced`` marker
            exists next to the data file, or the directory is ``run-*``).
        include_unsynced: Return runs not regarded as synced.
        exclude_globs: Patterns matched against the file names inside each
            run directory; matching files are ignored.
        include_globs: Patterns matched against the file names inside each
            run directory; when given, only matching files are kept.

    Returns:
        A tuple of ``_LocalRun`` objects, one per ``.wandb`` file found,
        ordered by run directory name. Empty when no base directory exists.
    """
    # TODO(jhr): grab dir info from settings
    base = ".wandb" if os.path.exists(".wandb") else "wandb"
    if not os.path.exists(base):
        return ()

    prefixes = []
    if include_offline:
        prefixes.append("offline-run-")
    if include_online:
        prefixes.append("run-")

    # Only real directories can be runs. A stray file such as "run-notes.txt"
    # used to reach os.listdir() below and raise NotADirectoryError.
    dirs = sorted(
        name
        for name in os.listdir(base)
        if name.startswith(tuple(prefixes))
        and os.path.isdir(os.path.join(base, name))
    )

    # find run file in each dir
    fnames = []
    for d in dirs:
        paths = os.listdir(os.path.join(base, d))
        # Filter in listing order instead of round-tripping through sets, so
        # the result order does not depend on string hashing.
        if exclude_globs:
            paths = [
                p
                for p in paths
                if not any(fnmatch.fnmatch(p, g) for g in exclude_globs)
            ]
        if include_globs:
            paths = [
                p for p in paths if any(fnmatch.fnmatch(p, g) for g in include_globs)
            ]
        fnames.extend(
            os.path.join(base, d, f) for f in paths if f.endswith(WANDB_SUFFIX)
        )

    filtered = []
    for f in fnames:
        dname = os.path.dirname(f)
        # TODO(frz): online runs are assumed to be synced, verify from binary log.
        is_synced = os.path.exists(f"{f}{SYNCED_SUFFIX}") or os.path.basename(
            dname
        ).startswith("run-")
        wanted = include_synced if is_synced else include_unsynced
        if wanted:
            filtered.append(_LocalRun(dname, is_synced))
    return tuple(filtered)


def get_run_from_path(path):
    """Wrap ``path`` in a ``_LocalRun`` with no sync state recorded."""
    return _LocalRun(path, synced=None)