Metadata-Version: 2.4
Name: jaxtyping
Version: 0.3.2
Summary: Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.
Project-URL: repository, https://github.com/google/jaxtyping
Author-email: Patrick Kidger <contact@kidger.site>
License: MIT License

Copyright (c) 2022 Google LLC

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Sections of the code were modified from https://github.com/agronholm/typeguard
under the terms of the MIT license, reproduced below.

MIT License

Copyright (c) Alex Grönholm

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
License-File: LICENSE
Keywords: deep-learning,equinox,jax,neural-networks,typing
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Financial and Insurance Industry
Classifier: Intended Audience :: Information Technology
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Information Analysis
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.10
Requires-Dist: wadler-lindig>=0.1.3
Provides-Extra: docs
Requires-Dist: hippogriffe==0.2.0; extra == 'docs'
Requires-Dist: mkdocs-include-exclude-files==0.1.0; extra == 'docs'
Requires-Dist: mkdocs-ipynb==0.1.0; extra == 'docs'
Requires-Dist: mkdocs-material==9.6.7; extra == 'docs'
Requires-Dist: mkdocs==1.6.1; extra == 'docs'
Requires-Dist: mkdocstrings[python]==0.28.3; extra == 'docs'
Requires-Dist: pymdown-extensions==10.14.3; extra == 'docs'
Description-Content-Type: text/markdown

<h1 align="center">jaxtyping</h1>

Type annotations **and runtime type-checking** for:

1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).

**For example:**
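```python
from jaxtyping import Array, Float, PyTree

# Floating-point 2D arrays. Axes that share a name (here `dim2`) are declared to have the same size.
# `Array` can be replaced with another array type, e.g. `torch.Tensor` or `np.ndarray`.
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                    ) -> Float[Array, "dim1 dim3"]:
    ...

# Any PyTree whose leaves are all `int`.
def accepts_pytree_of_ints(x: PyTree[int]):
    ...

# Any PyTree whose leaves are all 3D floating-point arrays.
def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...
```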
## Installation

```bash
pip install jaxtyping
```

Requires Python 3.10+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed, these types will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).
## Documentation

Available at [https://docs.kidger.site/jaxtyping](https://docs.kidger.site/jaxtyping).
## See also: other libraries in the JAX ecosystem

**Always useful**

- [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!

**Deep learning**

- [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
- [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
- [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
- [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.

**Scientific computing**

- [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
- [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
- [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
- [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
- [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**

- [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.