File size: 1,783 Bytes
f880d1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import annotations

from pathlib import Path

from mlip_arena.models import REGISTRY, MLIPEnum
from mlip_arena.tasks.stability.flow import compression, heating

if __name__ == "__main__":
    # Launch the stability benchmark: run the `heating` flow for every model in
    # MLIPEnum, then the `compression` flow for every model, with tasks
    # dispatched to a Dask cluster whose workers are SLURM jobs.
    #
    # Cluster-side imports are kept local to the entry point so that importing
    # this module does not require dask / dask_jobqueue / prefect_dask.
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster
    from prefect_dask import DaskTaskRunner

    nodes_per_alloc = 1
    gpus_per_alloc = 4

    cluster_kwargs = dict(
        cores=1,
        memory="64 GB",
        processes=1,
        shebang="#!/bin/bash",
        account="matgen",
        walltime="04:00:00",
        job_mem="0",
        # Shell lines executed at the top of every generated job script to set
        # up the Python environment on the compute node.
        job_script_prologue=[
            "source ~/.bashrc",
            "module load python",
            "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
        ],
        # Drop the default -n / --cpus-per-task / -J directives; the job name
        # is supplied explicitly below instead.
        job_directives_skip=["-n", "--cpus-per-task", "-J"],
        # Extra #SBATCH directives, passed through as written. Each allocation
        # requests `nodes_per_alloc` GPU nodes and `gpus_per_alloc` GPUs.
        job_extra_directives=[
            "-J arena-stability",
            "-q preempt",
            "--time-min=00:30:00",
            "--comment=12:00:00",
            f"-N {nodes_per_alloc}",
            "-C gpu",
            f"-G {gpus_per_alloc}",
        ],
    )

    base_dir = Path(__file__).parent
    failures = []

    # Context managers guarantee the client is closed and the cluster (with
    # any SLURM jobs it still holds) is shut down even if a flow raises.
    with SLURMCluster(**cluster_kwargs) as cluster:
        print(cluster.job_script())
        cluster.adapt(minimum_jobs=10, maximum_jobs=50)

        with Client(cluster) as client:
            scheduler_address = client.scheduler.address

            # Same ordering as before: every model is heated first, and only
            # then is every model compressed.
            for flow in (heating, compression):
                for model in MLIPEnum:
                    run_dir = base_dir / f"{REGISTRY[model.name]['family']}"

                    # Isolate each run so one failing model does not abort
                    # the remaining models or the following phase.
                    try:
                        flow.with_options(
                            task_runner=DaskTaskRunner(address=scheduler_address),
                            log_prints=True,
                        )(model, run_dir)
                    except Exception as err:
                        print(f"{flow.name} failed for {model.name}: {err!r}")
                        failures.append((flow.name, model.name))

    # Preserve a failing exit status when any run failed, as an uncaught
    # exception would have produced before.
    if failures:
        print(f"{len(failures)} flow run(s) failed: {failures}")
        raise SystemExit(1)