
lightning_learner

LightningLearner(module, num_workers=None, batch_size=32, max_epochs=100, accelerator='auto')

Bases: SupervisedLearner

A supervised learner that trains a PyTorch Lightning module.

Initialize the learner.

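Parameters:

- module (LightningModule, required): The PyTorch Lightning module to train.
- num_workers (int | None, default None): The number of worker processes for the DataLoader. If None (or 0), os.cpu_count() is used instead, falling back to 0 if the CPU count cannot be determined.
- batch_size (int, default 32): The number of samples per training batch.
- max_epochs (int, default 100): The maximum number of training epochs.
- accelerator (str, default 'auto'): The Lightning accelerator to use (for example 'cpu', 'gpu', or 'auto').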
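To see how batch_size and max_epochs interact, the sketch below computes an upper bound on the number of optimizer steps in one training run. It assumes, hypothetically, that the learner feeds a single device through a DataLoader with PyTorch's default drop_last=False and that Lightning takes one optimizer step per batch (no gradient accumulation); early stopping can only lower the count. The function total_training_steps is illustrative and not part of flowcean.

```python
def total_training_steps(num_samples: int, batch_size: int = 32, max_epochs: int = 100) -> int:
    """Upper bound on optimizer steps in one training run (illustrative only).

    Assumes drop_last=False (PyTorch's DataLoader default), so a final partial
    batch still counts, a single device, and one optimizer step per batch.
    Early stopping can only reduce the real number.
    """
    if num_samples < 0 or batch_size < 1 or max_epochs < 0:
        raise ValueError("expected num_samples >= 0, batch_size >= 1, max_epochs >= 0")
    # Ceiling division without floats: number of batches per epoch.
    batches_per_epoch = (num_samples + batch_size - 1) // batch_size
    return batches_per_epoch * max_epochs


# With the default batch_size=32 and max_epochs=100:
print(total_training_steps(1000))  # 32 batches per epoch -> 3200
```

Because max_epochs is a maximum, this figure is a ceiling on training work rather than a guaranteed step count.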
Source code in src/flowcean/torch/lightning_learner.py
def __init__(
    self,
    module: lightning.LightningModule,
    num_workers: int | None = None,
    batch_size: int = 32,
    max_epochs: int = 100,
    accelerator: str = "auto",
) -> None:
    """Initialize the learner.

    Args:
        module: The PyTorch Lightning module.
        num_workers: The number of workers to use for the DataLoader.
        batch_size: The batch size to use for training.
        max_epochs: The maximum number of epochs to train for.
        accelerator: The accelerator to use.
    """
    self.module = module
    self.num_workers = num_workers or os.cpu_count() or 0
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.optimizer = None
    self.accelerator = accelerator
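The num_workers line in __init__ relies on Python truthiness: None and an explicit 0 are both falsy, so either one is replaced by os.cpu_count(), and 0 is used only if os.cpu_count() returns None. The sketch below reproduces that expression in a standalone function next to a stricter variant that keeps an explicit 0, which in a PyTorch DataLoader means loading data in the main process. Both resolve_num_workers and resolve_num_workers_strict are hypothetical names, not flowcean API.

```python
from __future__ import annotations

import os


def resolve_num_workers(num_workers: int | None) -> int:
    """Reproduce the expression used in LightningLearner.__init__."""
    # None and 0 are both falsy, so both fall through to os.cpu_count();
    # os.cpu_count() may return None, in which case the result is 0.
    return num_workers or os.cpu_count() or 0


def resolve_num_workers_strict(num_workers: int | None) -> int:
    """Hypothetical variant: only None triggers the CPU-count fallback."""
    if num_workers is None:
        return os.cpu_count() or 0
    # An explicit 0 is kept, i.e. the DataLoader would load in the main process.
    return num_workers


print(resolve_num_workers(4))  # 4
print(resolve_num_workers(0) == resolve_num_workers(None))  # True
print(resolve_num_workers_strict(0))  # 0
```

With the original expression, passing num_workers=0 to the constructor does not disable worker processes on a machine where os.cpu_count() succeeds; checking `is None` instead of truthiness is one way to allow it.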