|
Metadata-Version: 2.4 |
|
Name: lightning |
|
Version: 2.5.2 |
|
Summary: The Deep Learning framework to train, deploy, and ship AI products Lightning fast. |
|
Home-page: https://github.com/Lightning-AI/lightning |
|
Download-URL: https://github.com/Lightning-AI/lightning |
|
Author: Lightning AI et al. |
|
Author-email: pytorch@lightning.ai |
|
License: Apache-2.0 |
|
Project-URL: Bug Tracker, https://github.com/Lightning-AI/pytorch-lightning/issues |
|
Project-URL: Documentation, https://lightning.ai/lightning-docs |
|
Project-URL: Source Code, https://github.com/Lightning-AI/lightning |
|
Keywords: deep learning,pytorch,AI |
|
Classifier: Environment :: Console |
|
Classifier: Natural Language :: English |
|
Classifier: Development Status :: 4 - Beta |
|
Classifier: Intended Audience :: Developers |
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence |
|
Classifier: Topic :: Scientific/Engineering :: Information Analysis |
|
Classifier: License :: OSI Approved :: Apache Software License |
|
Classifier: Operating System :: OS Independent |
|
Classifier: Programming Language :: Python :: 3 |
|
Classifier: Programming Language :: Python :: 3.9 |
|
Classifier: Programming Language :: Python :: 3.10 |
|
Classifier: Programming Language :: Python :: 3.11 |
|
Classifier: Programming Language :: Python :: 3.12 |
|
Requires-Python: >=3.9 |
|
Description-Content-Type: text/markdown |
|
License-File: LICENSE |
|
Requires-Dist: PyYAML<8.0,>=5.4 |
|
Requires-Dist: fsspec[http]<2027.0,>=2022.5.0 |
|
Requires-Dist: lightning-utilities<2.0,>=0.10.0 |
|
Requires-Dist: packaging<27.0,>=20.0 |
|
Requires-Dist: torch<4.0,>=2.1.0 |
|
Requires-Dist: torchmetrics<3.0,>=0.7.0 |
|
Requires-Dist: tqdm<6.0,>=4.57.0 |
|
Requires-Dist: typing-extensions<6.0,>=4.4.0 |
|
Requires-Dist: pytorch-lightning |
|
Provides-Extra: pytorch-strategies |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "pytorch-strategies" |
|
Provides-Extra: pytorch-test |
|
Requires-Dist: cloudpickle<4.0,>=1.3; extra == "pytorch-test" |
|
Requires-Dist: coverage==7.9.1; extra == "pytorch-test" |
|
Requires-Dist: fastapi; extra == "pytorch-test" |
|
Requires-Dist: numpy<2.0,>=1.17.2; extra == "pytorch-test" |
|
Requires-Dist: onnx<2.0,>=1.12.0; extra == "pytorch-test" |
|
Requires-Dist: onnxruntime<2.0,>=1.12.0; extra == "pytorch-test" |
|
Requires-Dist: pandas<3.0,>2.0; extra == "pytorch-test" |
|
Requires-Dist: psutil<8.0; extra == "pytorch-test" |
|
Requires-Dist: pytest-cov==6.2.1; extra == "pytorch-test" |
|
Requires-Dist: pytest-random-order==1.1.1; extra == "pytorch-test" |
|
Requires-Dist: pytest-rerunfailures==15.1; extra == "pytorch-test" |
|
Requires-Dist: pytest-timeout==2.4.0; extra == "pytorch-test" |
|
Requires-Dist: pytest==8.4.0; extra == "pytorch-test" |
|
Requires-Dist: scikit-learn<2.0,>0.22.1; extra == "pytorch-test" |
|
Requires-Dist: tensorboard<3.0,>=2.9.1; extra == "pytorch-test" |
|
Requires-Dist: uvicorn; extra == "pytorch-test" |
|
Provides-Extra: pytorch-extra |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "pytorch-extra" |
|
Requires-Dist: hydra-core<2.0,>=1.2.0; extra == "pytorch-extra" |
|
Requires-Dist: jsonargparse[jsonnet,signatures]<5.0,>=4.39.0; extra == "pytorch-extra" |
|
Requires-Dist: matplotlib<4.0,>3.1; extra == "pytorch-extra" |
|
Requires-Dist: omegaconf<3.0,>=2.2.3; extra == "pytorch-extra" |
|
Requires-Dist: rich<15.0,>=12.3.0; extra == "pytorch-extra" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "pytorch-extra" |
|
Provides-Extra: pytorch-examples |
|
Requires-Dist: ipython[all]<9.0; extra == "pytorch-examples" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "pytorch-examples" |
|
Requires-Dist: requests<3.0; extra == "pytorch-examples" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "pytorch-examples" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "pytorch-examples" |
|
Provides-Extra: fabric-strategies |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "fabric-strategies" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "fabric-strategies" |
|
Provides-Extra: fabric-test |
|
Requires-Dist: click==8.1.8; extra == "fabric-test" |
|
Requires-Dist: coverage==7.9.1; extra == "fabric-test" |
|
Requires-Dist: numpy<2.0,>=1.17.2; extra == "fabric-test" |
|
Requires-Dist: pytest-cov==6.2.1; extra == "fabric-test" |
|
Requires-Dist: pytest-random-order==1.1.1; extra == "fabric-test" |
|
Requires-Dist: pytest-rerunfailures==15.1; extra == "fabric-test" |
|
Requires-Dist: pytest-timeout==2.4.0; extra == "fabric-test" |
|
Requires-Dist: pytest==8.4.0; extra == "fabric-test" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "fabric-test" |
|
Provides-Extra: fabric-examples |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "fabric-examples" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "fabric-examples" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "fabric-examples" |
|
Provides-Extra: fabric-all |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "fabric-all" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "fabric-all" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "fabric-all" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "fabric-all" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "fabric-all" |
|
Provides-Extra: fabric-dev |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "fabric-dev" |
|
Requires-Dist: click==8.1.8; extra == "fabric-dev" |
|
Requires-Dist: coverage==7.9.1; extra == "fabric-dev" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "fabric-dev" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "fabric-dev" |
|
Requires-Dist: numpy<2.0,>=1.17.2; extra == "fabric-dev" |
|
Requires-Dist: pytest-cov==6.2.1; extra == "fabric-dev" |
|
Requires-Dist: pytest-random-order==1.1.1; extra == "fabric-dev" |
|
Requires-Dist: pytest-rerunfailures==15.1; extra == "fabric-dev" |
|
Requires-Dist: pytest-timeout==2.4.0; extra == "fabric-dev" |
|
Requires-Dist: pytest==8.4.0; extra == "fabric-dev" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "fabric-dev" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "fabric-dev" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "fabric-dev" |
|
Provides-Extra: pytorch-all |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "pytorch-all" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "pytorch-all" |
|
Requires-Dist: hydra-core<2.0,>=1.2.0; extra == "pytorch-all" |
|
Requires-Dist: ipython[all]<9.0; extra == "pytorch-all" |
|
Requires-Dist: jsonargparse[jsonnet,signatures]<5.0,>=4.39.0; extra == "pytorch-all" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "pytorch-all" |
|
Requires-Dist: matplotlib<4.0,>3.1; extra == "pytorch-all" |
|
Requires-Dist: omegaconf<3.0,>=2.2.3; extra == "pytorch-all" |
|
Requires-Dist: requests<3.0; extra == "pytorch-all" |
|
Requires-Dist: rich<15.0,>=12.3.0; extra == "pytorch-all" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "pytorch-all" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "pytorch-all" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "pytorch-all" |
|
Provides-Extra: pytorch-dev |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "pytorch-dev" |
|
Requires-Dist: cloudpickle<4.0,>=1.3; extra == "pytorch-dev" |
|
Requires-Dist: coverage==7.9.1; extra == "pytorch-dev" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "pytorch-dev" |
|
Requires-Dist: fastapi; extra == "pytorch-dev" |
|
Requires-Dist: hydra-core<2.0,>=1.2.0; extra == "pytorch-dev" |
|
Requires-Dist: ipython[all]<9.0; extra == "pytorch-dev" |
|
Requires-Dist: jsonargparse[jsonnet,signatures]<5.0,>=4.39.0; extra == "pytorch-dev" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "pytorch-dev" |
|
Requires-Dist: matplotlib<4.0,>3.1; extra == "pytorch-dev" |
|
Requires-Dist: numpy<2.0,>=1.17.2; extra == "pytorch-dev" |
|
Requires-Dist: omegaconf<3.0,>=2.2.3; extra == "pytorch-dev" |
|
Requires-Dist: onnx<2.0,>=1.12.0; extra == "pytorch-dev" |
|
Requires-Dist: onnxruntime<2.0,>=1.12.0; extra == "pytorch-dev" |
|
Requires-Dist: pandas<3.0,>2.0; extra == "pytorch-dev" |
|
Requires-Dist: psutil<8.0; extra == "pytorch-dev" |
|
Requires-Dist: pytest-cov==6.2.1; extra == "pytorch-dev" |
|
Requires-Dist: pytest-random-order==1.1.1; extra == "pytorch-dev" |
|
Requires-Dist: pytest-rerunfailures==15.1; extra == "pytorch-dev" |
|
Requires-Dist: pytest-timeout==2.4.0; extra == "pytorch-dev" |
|
Requires-Dist: pytest==8.4.0; extra == "pytorch-dev" |
|
Requires-Dist: requests<3.0; extra == "pytorch-dev" |
|
Requires-Dist: rich<15.0,>=12.3.0; extra == "pytorch-dev" |
|
Requires-Dist: scikit-learn<2.0,>0.22.1; extra == "pytorch-dev" |
|
Requires-Dist: tensorboard<3.0,>=2.9.1; extra == "pytorch-dev" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "pytorch-dev" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "pytorch-dev" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "pytorch-dev" |
|
Requires-Dist: uvicorn; extra == "pytorch-dev" |
|
Provides-Extra: strategies |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "strategies" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "strategies" |
|
Provides-Extra: test |
|
Requires-Dist: click==8.1.8; extra == "test" |
|
Requires-Dist: cloudpickle<4.0,>=1.3; extra == "test" |
|
Requires-Dist: coverage==7.9.1; extra == "test" |
|
Requires-Dist: fastapi; extra == "test" |
|
Requires-Dist: numpy<2.0,>=1.17.2; extra == "test" |
|
Requires-Dist: onnx<2.0,>=1.12.0; extra == "test" |
|
Requires-Dist: onnxruntime<2.0,>=1.12.0; extra == "test" |
|
Requires-Dist: pandas<3.0,>2.0; extra == "test" |
|
Requires-Dist: psutil<8.0; extra == "test" |
|
Requires-Dist: pytest-cov==6.2.1; extra == "test" |
|
Requires-Dist: pytest-random-order==1.1.1; extra == "test" |
|
Requires-Dist: pytest-rerunfailures==15.1; extra == "test" |
|
Requires-Dist: pytest-timeout==2.4.0; extra == "test" |
|
Requires-Dist: pytest==8.4.0; extra == "test" |
|
Requires-Dist: scikit-learn<2.0,>0.22.1; extra == "test" |
|
Requires-Dist: tensorboard<3.0,>=2.9.1; extra == "test" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "test" |
|
Requires-Dist: uvicorn; extra == "test" |
|
Provides-Extra: extra |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "extra" |
|
Requires-Dist: hydra-core<2.0,>=1.2.0; extra == "extra" |
|
Requires-Dist: jsonargparse[jsonnet,signatures]<5.0,>=4.39.0; extra == "extra" |
|
Requires-Dist: matplotlib<4.0,>3.1; extra == "extra" |
|
Requires-Dist: omegaconf<3.0,>=2.2.3; extra == "extra" |
|
Requires-Dist: rich<15.0,>=12.3.0; extra == "extra" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "extra" |
|
Provides-Extra: examples |
|
Requires-Dist: ipython[all]<9.0; extra == "examples" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "examples" |
|
Requires-Dist: requests<3.0; extra == "examples" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "examples" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "examples" |
|
Provides-Extra: data |
|
Requires-Dist: litdata<1.0,>=0.2.0rc; extra == "data" |
|
Provides-Extra: all |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "all" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "all" |
|
Requires-Dist: hydra-core<2.0,>=1.2.0; extra == "all" |
|
Requires-Dist: ipython[all]<9.0; extra == "all" |
|
Requires-Dist: jsonargparse[jsonnet,signatures]<5.0,>=4.39.0; extra == "all" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "all" |
|
Requires-Dist: matplotlib<4.0,>3.1; extra == "all" |
|
Requires-Dist: omegaconf<3.0,>=2.2.3; extra == "all" |
|
Requires-Dist: requests<3.0; extra == "all" |
|
Requires-Dist: rich<15.0,>=12.3.0; extra == "all" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "all" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "all" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "all" |
|
Provides-Extra: dev |
|
Requires-Dist: bitsandbytes<1.0,>=0.45.2; platform_system != "Darwin" and extra == "dev" |
|
Requires-Dist: click==8.1.8; extra == "dev" |
|
Requires-Dist: cloudpickle<4.0,>=1.3; extra == "dev" |
|
Requires-Dist: coverage==7.9.1; extra == "dev" |
|
Requires-Dist: deepspeed<=0.9.3,>=0.8.2; (platform_system != "Windows" and platform_system != "Darwin") and extra == "dev" |
|
Requires-Dist: fastapi; extra == "dev" |
|
Requires-Dist: hydra-core<2.0,>=1.2.0; extra == "dev" |
|
Requires-Dist: ipython[all]<9.0; extra == "dev" |
|
Requires-Dist: jsonargparse[jsonnet,signatures]<5.0,>=4.39.0; extra == "dev" |
|
Requires-Dist: lightning-utilities<1.0,>=0.8.0; extra == "dev" |
|
Requires-Dist: matplotlib<4.0,>3.1; extra == "dev" |
|
Requires-Dist: numpy<2.0,>=1.17.2; extra == "dev" |
|
Requires-Dist: omegaconf<3.0,>=2.2.3; extra == "dev" |
|
Requires-Dist: onnx<2.0,>=1.12.0; extra == "dev" |
|
Requires-Dist: onnxruntime<2.0,>=1.12.0; extra == "dev" |
|
Requires-Dist: pandas<3.0,>2.0; extra == "dev" |
|
Requires-Dist: psutil<8.0; extra == "dev" |
|
Requires-Dist: pytest-cov==6.2.1; extra == "dev" |
|
Requires-Dist: pytest-random-order==1.1.1; extra == "dev" |
|
Requires-Dist: pytest-rerunfailures==15.1; extra == "dev" |
|
Requires-Dist: pytest-timeout==2.4.0; extra == "dev" |
|
Requires-Dist: pytest==8.4.0; extra == "dev" |
|
Requires-Dist: requests<3.0; extra == "dev" |
|
Requires-Dist: rich<15.0,>=12.3.0; extra == "dev" |
|
Requires-Dist: scikit-learn<2.0,>0.22.1; extra == "dev" |
|
Requires-Dist: tensorboard<3.0,>=2.9.1; extra == "dev" |
|
Requires-Dist: tensorboardX<3.0,>=2.2; extra == "dev" |
|
Requires-Dist: torchmetrics<2.0,>=0.10.0; extra == "dev" |
|
Requires-Dist: torchvision<1.0,>=0.16.0; extra == "dev" |
|
Requires-Dist: uvicorn; extra == "dev" |
|
Dynamic: author |
|
Dynamic: author-email |
|
Dynamic: classifier |
|
Dynamic: description |
|
Dynamic: description-content-type |
|
Dynamic: download-url |
|
Dynamic: home-page |
|
Dynamic: keywords |
|
Dynamic: license |
|
Dynamic: license-file |
|
Dynamic: project-url |
|
Dynamic: provides-extra |
|
Dynamic: requires-dist |
|
Dynamic: requires-python |
|
Dynamic: summary |
|
|
|
<div align="center"> |
|
|
|
<img alt="Lightning" src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/ptl_banner.png" width="800px" style="max-width: 100%;"> |
|
|
|
<br/> |
|
<br/> |
|
|
|
**The deep learning framework to pretrain, finetune and deploy AI models.** |
|
|
|
**NEW: Deploying models? Check out [LitServe](https://github.com/Lightning-AI/litserve), the PyTorch Lightning for model serving.**
|
|
|
______________________________________________________________________ |
|
|
|
<p align="center"> |
|
<a href="#quick-start" style="margin: 0 10px;">Quick start</a> •

<a href="#examples">Examples</a> •

<a href="#why-pytorch-lightning">PyTorch Lightning</a> •

<a href="#lightning-fabric-expert-control">Fabric</a> •

<a href="https://lightning.ai/">Lightning AI</a> •

<a href="#community">Community</a> •
|
<a href="https://pytorch-lightning.readthedocs.io/en/2.5.2">Docs</a> |
|
</p> |
|
|
|
<!-- DO NOT ADD CONDA DOWNLOADS... README CHANGES MUST BE APPROVED BY EDEN OR WILL --> |
|
|
|
[PyPI](https://pypi.org/project/pytorch-lightning/) •
[PyPI version](https://badge.fury.io/py/pytorch-lightning) •
[Downloads](https://pepy.tech/project/pytorch-lightning) •
[Conda](https://anaconda.org/conda-forge/lightning) •
[Codecov](https://codecov.io/gh/Lightning-AI/pytorch-lightning)

[Discord](https://discord.gg/VptPCZkGNa) •
[License](https://github.com/Lightning-AI/pytorch-lightning/blob/master/LICENSE)
|
|
|
<!-- |
|
[](https://www.codefactor.io/repository/github/Lightning-AI/lightning) |
|
--> |
|
|
|
</div> |
|
|
|
<div align="center"> |
|
|
|
<p align="center"> |
|
|
|
|
|
|
|
<a target="_blank" href="https://lightning.ai/docs/pytorch/latest/starter/introduction.html#define-a-lightningmodule"> |
|
<img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/get-started-badge.svg" height="36px" alt="Get started"/> |
|
</a> |
|
|
|
</p> |
|
|
|
</div> |
|
|
|
|
|
|
|
# Lightning has 2 core packages |
|
|
|
[PyTorch Lightning: Train and deploy PyTorch at scale](#why-pytorch-lightning). |
|
<br/> |
|
[Lightning Fabric: Expert control](#lightning-fabric-expert-control). |
|
|
|
Lightning gives you granular control over how much abstraction you want to add over PyTorch. |
|
|
|
<div align="center"> |
|
<img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/continuum.png" width="80%"> |
|
</div> |
|
|
|
|
|
|
|
# Quick start |
|
Install Lightning: |
|
|
|
```bash
pip install lightning
```
|
|
|
|
|
|
### PyTorch Lightning example |
|
Define the training workflow. Here's a toy example ([explore real examples](https://lightning.ai/lightning-ai/studios?view=public&section=featured&query=pytorch+lightning)):
|
|
|
```python |
|
# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (e.g., an LLM, diffusion model, autoencoder, or simple image classifier).


class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))
|
``` |
|
|
|
Run the model from your terminal:
|
|
|
```bash
pip install torchvision
python main.py
``` |
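
Step 2 splits MNIST's 60,000 training images into 55,000 for training and 5,000 for validation, and the integer lengths passed to `data.random_split` must add up to the dataset size. As a rough, plain-Python sketch of the index bookkeeping behind such a split (not PyTorch's implementation; `random_split_indices` and its `seed` argument are hypothetical names used only for illustration):

```python
import random


def random_split_indices(n, lengths, seed=42):
    """Shuffle range(n) and cut it into consecutive, non-overlapping chunks."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits


train_idx, val_idx = random_split_indices(60_000, [55_000, 5_000])
print(len(train_idx), len(val_idx))  # 55000 5000
```

Seeding the shuffle keeps the train/validation membership identical across runs, which matters when comparing validation metrics between experiments.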
|
|
|
|
|
|
|
|
|
# Why PyTorch Lightning? |
|
|
|
PyTorch Lightning is just organized PyTorch: Lightning disentangles PyTorch code to decouple the science from the engineering.
|
|
|
|
|
|
|
|
|
|
|
---- |
|
|
|
### Examples |
|
Explore the kinds of training you can do with PyTorch Lightning. Pretrain and finetune ANY kind of model to perform ANY task, such as classification, segmentation, summarization, and more:
|
|
|
| Task | Description | Run |
|-------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------|---|
| [Hello world](#hello-simple-model) | Pretrain - Hello world example | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/pytorch-lightning-hello-world"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Image classification](https://lightning.ai/lightning-ai/studios/image-classification-with-pytorch-lightning) | Finetune - ResNet-34 model to classify images of cars | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/image-classification-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Image segmentation](https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning) | Finetune - ResNet-50 model to segment images | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Object detection](https://lightning.ai/lightning-ai/studios/object-detection-with-pytorch-lightning) | Finetune - Faster R-CNN model to detect objects | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/object-detection-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Text classification](https://lightning.ai/lightning-ai/studios/text-classification-with-pytorch-lightning) | Finetune - text classifier (BERT model) | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/text-classification-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Text summarization](https://lightning.ai/lightning-ai/studios/text-summarization-with-pytorch-lightning) | Finetune - text summarization (Hugging Face transformer model) | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/text-summarization-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Audio generation](https://lightning.ai/lightning-ai/studios/finetune-a-personal-ai-music-generator) | Finetune - audio generator (transformer model) | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/finetune-a-personal-ai-music-generator"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [LLM finetuning](https://lightning.ai/lightning-ai/studios/finetune-an-llm-with-pytorch-lightning) | Finetune - LLM (Meta Llama 3.1 8B) | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/finetune-an-llm-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Image generation](https://lightning.ai/lightning-ai/studios/train-a-diffusion-model-with-pytorch-lightning) | Pretrain - Image generator (diffusion model) | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/train-a-diffusion-model-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Recommendation system](https://lightning.ai/lightning-ai/studios/recommendation-system-with-pytorch-lightning) | Train - recommendation system (factorization and embedding) | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/recommendation-system-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
| [Time-series forecasting](https://lightning.ai/lightning-ai/studios/time-series-forecasting-with-pytorch-lightning) | Train - Time-series forecasting with LSTM | <a target="_blank" href="https://lightning.ai/lightning-ai/studios/time-series-forecasting-with-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio"/></a> |
|
|
|
______________________________________________________________________ |
|
|
|
## Advanced features |
|
|
|
Lightning has [40+ advanced features](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-flags) designed for professional AI research at scale.
|
|
|
Here are some examples: |
|
|
|
<div align="center"> |
|
<img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/features_2.jpg" max-height="600px"> |
|
</div> |
|
|
|
<details> |
|
<summary>Train on 1000s of GPUs without code changes</summary> |
|
|
|
```python |
|
from lightning import Trainer

# 8 GPUs
# no code changes needed
trainer = Trainer(accelerator="gpu", devices=8)

# 256 GPUs
trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
|
``` |
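
The 256-GPU figure is `devices` (GPUs per node) multiplied by `num_nodes`: 8 × 32 = 256 processes. Under data-parallel strategies such as DDP, each process consumes its own batch, so the number of samples per optimizer step scales with that product too. A small plain-Python sketch of the arithmetic; the helper names are hypothetical, and `accumulate_grad_batches` mirrors the Trainer flag of that name, assuming gradient accumulation is enabled:

```python
def world_size(devices: int, num_nodes: int = 1) -> int:
    """Total number of training processes when each device runs one process."""
    return devices * num_nodes


def effective_batch_size(per_device_batch: int, devices: int, num_nodes: int = 1,
                         accumulate_grad_batches: int = 1) -> int:
    """Samples that contribute to one optimizer step under data-parallel training."""
    return per_device_batch * world_size(devices, num_nodes) * accumulate_grad_batches


print(world_size(8, num_nodes=32))                                 # 256
print(effective_batch_size(32, devices=8, num_nodes=32))           # 8192
print(effective_batch_size(4, devices=2, accumulate_grad_batches=4))  # 32
```

Because the effective batch grows with the number of GPUs, learning-rate and schedule settings tuned on one device may need revisiting when scaling out.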
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Train on other accelerators like TPUs without code changes</summary> |
|
|
|
```python |
|
from lightning import Trainer

# no code changes needed
trainer = Trainer(accelerator="tpu", devices=8)
|
``` |
|
|
|
</details> |
|
|
|
<details> |
|
<summary>16-bit precision</summary> |
|
|
|
```python |
|
from lightning import Trainer

# no code changes needed; "16-mixed" selects 16-bit mixed precision
trainer = Trainer(precision="16-mixed")
|
``` |
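
With 16-bit mixed precision, most operations run in half precision while the weights stay in full precision, and on CUDA the loss is scaled so that small gradients do not underflow. The plain-Python sketch below uses the standard library's `struct` half-precision format `'e'` (`to_float16` is a hypothetical helper) to show the rounding, underflow, and overflow behaviour that motivates those measures:

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Rounding: only about 3 significant decimal digits survive.
print(to_float16(0.1))  # 0.0999755859375

# Underflow: a tiny gradient becomes exactly zero...
print(to_float16(1e-8))  # 0.0

# ...but survives if it is scaled up before the cast and divided back
# down in full precision, which is the idea behind loss scaling.
scale = 2.0 ** 16
print(to_float16(1e-8 * scale) / scale)  # approximately 1e-08

# Overflow: 65504 is the largest finite half-precision value.
print(to_float16(65504.0))  # 65504.0
try:
    to_float16(1e5)
except OverflowError as err:
    print("overflow:", err)
```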
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Experiment managers</summary> |
|
|
|
```python |
|
from lightning.pytorch import Trainer, loggers

# each logger needs its own package installed (e.g. tensorboard, wandb, comet_ml, mlflow, neptune)

# tensorboard
trainer = Trainer(logger=loggers.TensorBoardLogger("logs/"))

# weights and biases
trainer = Trainer(logger=loggers.WandbLogger())

# comet
trainer = Trainer(logger=loggers.CometLogger())

# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())

# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())

# ... and dozens more
|
``` |
|
|
|
</details> |
|
|
|
<details> |
|
|
|
<summary>Early Stopping</summary> |
|
|
|
```python |
|
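from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# stop training once the logged "val_loss" stops improving
es = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[es])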
|
``` |
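
`EarlyStopping(monitor="val_loss")` watches a metric you log (for example with `self.log("val_loss", loss)` in a validation step) and stops once it has not improved for `patience` consecutive checks. The function below is a simplified, hypothetical re-implementation of that patience logic in plain Python, assuming `mode="min"` and mirroring the callback's defaults of `patience=3` and `min_delta=0.0`; the real callback handles more cases, such as `mode="max"` and stopping or divergence thresholds.

```python
import math


def early_stop_check(val_losses, patience=3, min_delta=0.0):
    """Return the index of the check at which training would stop, or None."""
    best = math.inf
    wait = 0
    for check, loss in enumerate(val_losses):
        if loss < best - min_delta:  # improved by more than min_delta
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return check
    return None


print(early_stop_check([1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]))  # 5
```

Note that 0.70 at index 5 only ties the best value, so it does not reset the counter; with `min_delta > 0`, small improvements are ignored in the same way.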
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Checkpointing</summary> |
|
|
|
```python |
|
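from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# save checkpoints ranked by the logged "val_loss"
checkpointing = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpointing])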
|
``` |
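
`ModelCheckpoint(monitor="val_loss")` keeps the best checkpoints according to the monitored metric; how many it keeps is controlled by its `save_top_k` parameter. The hypothetical `TopKCheckpoints` class below sketches that bookkeeping in plain Python with a heap, assuming lower is better (`mode="min"`). It is not Lightning's implementation, and the checkpoint file names are illustrative only.

```python
import heapq


class TopKCheckpoints:
    """Track the k checkpoints with the lowest monitored value."""

    def __init__(self, k: int = 1):
        self.k = k
        self._heap = []  # entries are (-score, path), so the worst score pops first

    def update(self, score: float, path: str):
        """Register a checkpoint; return the path that falls out of the top k, if any."""
        heapq.heappush(self._heap, (-score, path))
        if len(self._heap) > self.k:
            _, evicted = heapq.heappop(self._heap)
            return evicted
        return None

    def best(self):
        """Return (score, path) pairs, best first."""
        return sorted((-neg, path) for neg, path in self._heap)


tracker = TopKCheckpoints(k=2)
for epoch, val_loss in enumerate([0.50, 0.30, 0.40, 0.35]):
    evicted = tracker.update(val_loss, f"epoch={epoch}.ckpt")
    if evicted:
        print("delete", evicted)  # epoch=0.ckpt, then epoch=2.ckpt
print(tracker.best())  # [(0.3, 'epoch=1.ckpt'), (0.35, 'epoch=3.ckpt')]
```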
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Export to TorchScript (JIT) for production use</summary>
|
|
|
```python |
|
import torch

# torchscript (LitAutoEncoder is defined in the quick start above)
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
|
``` |
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Export to ONNX for production use</summary>
|
|
|
```python |
|
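import os
import tempfile

import torch

# onnx (requires the `onnx` package; LitAutoEncoder is defined in the quick start above)
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
    autoencoder = LitAutoEncoder()
    # the encoder expects flattened 28x28 images, so the sample needs 784 features
    input_sample = torch.randn((1, 28 * 28))
    autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    assert os.path.isfile(tmpfile.name)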
|
``` |
|
|
|
</details> |
|
|
|
______________________________________________________________________ |
|
|
|
## Advantages over unstructured PyTorch |
|
|
|
- Models become hardware agnostic |
|
- Code is clear to read because engineering code is abstracted away |
|
- Easier to reproduce |
|
- Make fewer mistakes because Lightning handles the tricky engineering
|
- Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate |
|
- Lightning has dozens of integrations with popular machine learning tools. |
|
- [Tested rigorously with every new PR](https://github.com/Lightning-AI/lightning/tree/master/tests). We test every supported combination of PyTorch and Python versions, every OS, multi-GPU setups, and even TPUs.
|
- Minimal running speed overhead (about 300 ms per epoch compared with pure PyTorch). |
|
|
|
______________________________________________________________________ |
|
|
|
<div align="center"> |
|
<a href="https://lightning.ai/docs/pytorch/stable/">Read the PyTorch Lightning docs</a> |
|
</div> |
|
|
|
______________________________________________________________________ |
|
|
|
|
|
|
|
|
|
# Lightning Fabric: Expert control |
|
|
|
Run on any device at any scale with expert-level control over the PyTorch training loop and scaling strategy. You can even write your own Trainer.
|
|
|
Fabric is designed for the most complex models and workloads, such as foundation model scaling, LLMs, diffusion, transformers, reinforcement learning, and active learning, at any size.
|
|
|
<table> |
|
<tr> |
|
<th>What to change</th> |
|
<th>Resulting Fabric Code (copy me!)</th> |
|
</tr> |
|
<tr> |
|
<td> |
|
<sub> |
|
|
|
```diff |
|
+ import lightning as L
  import torch; import torchvision as tv

  dataset = tv.datasets.CIFAR10("data", download=True,
                                train=True,
                                transform=tv.transforms.ToTensor())

+ fabric = L.Fabric()
+ fabric.launch()

  model = tv.models.resnet18()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)

  dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)

  model.train()
  num_epochs = 10
  for epoch in range(num_epochs):
      for batch in dataloader:
          inputs, labels = batch
-         inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = torch.nn.functional.cross_entropy(outputs, labels)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          print(loss.data)
|
``` |
|
|
|
</sub> |
|
</td>
<td>
|
<sub> |
|
|
|
```python
import lightning as L
import torch; import torchvision as tv

dataset = tv.datasets.CIFAR10("data", download=True,
                              train=True,
                              transform=tv.transforms.ToTensor())

fabric = L.Fabric()
fabric.launch()

model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        fabric.backward(loss)
        optimizer.step()
        print(loss.data)
|
``` |
|
|
|
</sub> |
|
</td> |
|
</tr> |
|
</table> |
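
When Fabric runs with a distributed strategy, `fabric.setup_dataloaders` can replace the dataloader's sampler with a distributed one, so each process sees a different slice of the dataset. The sketch below mimics the index assignment of PyTorch's `DistributedSampler` without shuffling: pad by wrapping around to a multiple of the number of processes, then take every `num_replicas`-th index starting at the process rank. `shard_indices` is a hypothetical helper, not a Lightning or PyTorch API.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices that process `rank` would read, without shuffling."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # wrap around to the start so every rank gets the same number of samples
    padding = total_size - dataset_len
    indices += [indices[i % dataset_len] for i in range(padding)]
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

The padding means a few samples (here 0 and 1) are seen twice per epoch; this is what keeps all processes stepping in lockstep.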
|
|
|
## Key features |
|
|
|
<details> |
|
<summary>Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training</summary>
|
|
|
```python |
|
from lightning.fabric import Fabric

# Use your available hardware
# no code changes needed
fabric = Fabric()

# Run on GPUs (CUDA or MPS)
fabric = Fabric(accelerator="gpu")

# 8 GPUs
fabric = Fabric(accelerator="gpu", devices=8)

# 256 GPUs, multi-node
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)

# Run on TPUs
fabric = Fabric(accelerator="tpu")
|
``` |
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box</summary> |
|
|
|
```python |
|
from lightning.fabric import Fabric

# Use state-of-the-art distributed training techniques
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")
fabric = Fabric(strategy="fsdp")

# Switch the precision
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64")
|
``` |
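
Under `strategy="ddp"`, every process computes gradients on its own shard of the data and the gradients are averaged across processes (an all-reduce) before the optimizer step. With equal-sized shards and a mean-reduced loss, that average equals the full-batch gradient. The plain-Python sketch below checks this for a one-parameter least-squares model; it illustrates the principle only and is not how PyTorch implements the all-reduce.

```python
def grad_mse(w, xs, ys):
    """Derivative w.r.t. w of mean((w * x - y) ** 2) over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Single process: gradient over the whole batch.
full = grad_mse(w, xs, ys)

# Two simulated "ranks" with equal-sized shards, followed by an all-reduce mean.
shards = [(xs[0::2], ys[0::2]), (xs[1::2], ys[1::2])]
averaged = sum(grad_mse(w, sx, sy) for sx, sy in shards) / len(shards)

print(full, averaged)  # -22.5 -22.5
```

If the shards had different sizes, a plain mean of per-shard gradients would weight samples unequally, which is one reason distributed samplers pad every process to the same number of samples.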
|
|
|
</details> |
|
|
|
<details> |
|
<summary>All the device logic boilerplate is handled for you</summary> |
|
|
|
```diff |
|
  # no more of this!
- model.to(device)
- batch.to(device)
|
``` |
|
|
|
</details> |
|
|
|
<details> |
|
<summary>Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more</summary>
|
|
|
```python
import lightning as L
import torch


class MyCustomTrainer:
    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)

    def fit(self, model, optimizer, dataloader, max_epochs, loss_fn=torch.nn.functional.cross_entropy):
        self.fabric.launch()

        # Fabric handles device placement, precision and distributed wrapping
        model, optimizer = self.fabric.setup(model, optimizer)
        dataloader = self.fabric.setup_dataloaders(dataloader)
        model.train()

        for epoch in range(max_epochs):
            for batch in dataloader:
                inputs, target = batch
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                self.fabric.backward(loss)  # replaces loss.backward()
                optimizer.step()
```
|
|
|
You can find a more extensive example in our [examples](examples/fabric/build_your_own_trainer).
|
|
|
</details> |
|
|
|
______________________________________________________________________ |
|
|
|
<div align="center"> |
|
<a href="https://lightning.ai/docs/fabric/stable/">Read the Lightning Fabric docs</a> |
|
</div> |
|
|
|
______________________________________________________________________ |
|
|
|
|
|
|
|
|
|
## Examples |
|
|
|
###### Self-supervised Learning |
|
|
|
- [CPC transforms](https://lightning-bolts.readthedocs.io/en/stable/transforms/self_supervised.html#cpc-transforms) |
|
- [MoCo v2 transforms](https://lightning-bolts.readthedocs.io/en/stable/transforms/self_supervised.html#moco-v2-transforms)
|
- [SimCLR transforms](https://lightning-bolts.readthedocs.io/en/stable/transforms/self_supervised.html#simclr-transforms) |
|
|
|
###### Convolutional Architectures |
|
|
|
- [GPT-2](https://lightning-bolts.readthedocs.io/en/stable/models/convolutional.html#gpt-2) |
|
- [UNet](https://lightning-bolts.readthedocs.io/en/stable/models/convolutional.html#unet) |
|
|
|
###### Reinforcement Learning |
|
|
|
- [DQN Loss](https://lightning-bolts.readthedocs.io/en/stable/losses.html#dqn-loss) |
|
- [Double DQN Loss](https://lightning-bolts.readthedocs.io/en/stable/losses.html#double-dqn-loss) |
|
- [Prioritized Experience Replay (PER) DQN Loss](https://lightning-bolts.readthedocs.io/en/stable/losses.html#per-dqn-loss)
|
|
|
###### GANs |
|
|
|
- [Basic GAN](https://lightning-bolts.readthedocs.io/en/stable/models/gans.html#basic-gan) |
|
- [DCGAN](https://lightning-bolts.readthedocs.io/en/stable/models/gans.html#dcgan) |
|
|
|
###### Classic ML |
|
|
|
- [Logistic Regression](https://lightning-bolts.readthedocs.io/en/stable/models/classic_ml.html#logistic-regression) |
|
- [Linear Regression](https://lightning-bolts.readthedocs.io/en/stable/models/classic_ml.html#linear-regression) |
|
|
|
|
|
|
|
|
|
## Continuous Integration |
|
|
|
Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against major Python and PyTorch versions. |
|
|
|
###### \*Codecov coverage is above 90%, but build delays may cause the badge to show less
|
|
|
<details> |
|
<summary>Current build statuses</summary> |
|
|
|
<center> |
|
|
|
| System / PyTorch ver.              | 1.13 | 2.0 | 2.1 |
| :--------------------------------: | :--: | :-: | :-: |
| Linux py3.9 \[GPUs\]               |      |     | [status](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=24&branchName=master) |
| Linux (multiple Python versions)   | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
| OSX (multiple Python versions)     | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
| Windows (multiple Python versions) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [status](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
|
|
|
</center> |
|
</details> |
|
|
|
|
|
|
|
|
|
## Community |
|
|
|
The Lightning community is maintained by:
|
|
|
- [10+ core contributors](https://lightning.ai/docs/pytorch/latest/community/governance.html), a mix of professional engineers, research scientists, and Ph.D. students from top AI labs.
|
- 800+ community contributors. |
|
|
|
Want to help us build Lightning and reduce boilerplate for thousands of researchers? [Learn how to make your first contribution here](https://lightning.ai/docs/pytorch/stable/generated/CONTRIBUTING.html).
|
|
|
Lightning is also part of the [PyTorch ecosystem](https://pytorch.org/ecosystem/), which requires projects to have solid testing, documentation and support.
|
|
|
### Asking for help |
|
|
|
If you have any questions, please:
|
|
|
1. [Read the docs](https://lightning.ai/docs). |
|
1. [Search through existing Discussions](https://github.com/Lightning-AI/lightning/discussions), or [add a new question](https://github.com/Lightning-AI/lightning/discussions/new).
|
1. [Join our Discord](https://discord.com/invite/tfXFetEZxv).
|
|