Skip to content

lightning_learner

LightningCallbackBridge(learner, callback_manager)

Bases: Callback

Bridge between PyTorch Lightning callbacks and flowcean callbacks.

This adapter forwards Lightning training events to flowcean callbacks.

Source code in src/flowcean/torch/lightning_learner.py
28
29
30
31
32
33
34
35
def __init__(
    self,
    learner: Named,
    callback_manager: Any,
) -> None:
    """Store the learner and manager that Lightning events are forwarded to.

    Args:
        learner: The flowcean learner on whose behalf events are reported.
        callback_manager: Manager that dispatches to flowcean callbacks.
    """
    super().__init__()
    self.callback_manager = callback_manager
    self.learner = learner

on_train_start(trainer, pl_module)

Called when training starts.

Source code in src/flowcean/torch/lightning_learner.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def on_train_start(
    self,
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,  # noqa: ARG002
) -> None:
    """Forward Lightning's training-start event to flowcean callbacks."""
    # The batch size is only available once a train dataloader has been
    # attached to the trainer; report None otherwise.
    dataloader = trainer.train_dataloader
    batch_size = None
    if dataloader:
        batch_size = dataloader.batch_size  # type: ignore[union-attr]
    context = {
        "max_epochs": trainer.max_epochs,
        "batch_size": batch_size,
    }
    self.callback_manager.on_learning_start(self.learner, context)

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called after each training batch.

Source code in src/flowcean/torch/lightning_learner.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def on_train_batch_end(
    self,
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs: Any,  # noqa: ARG002
    batch: Any,  # noqa: ARG002
    batch_idx: int,
) -> None:
    """Report per-batch progress and logged metrics to flowcean callbacks."""
    # Overall training progress in [0, 1]: completed epochs plus the
    # fraction of the current epoch, normalised by max_epochs. Left as
    # None when either quantity is unavailable (falsy).
    progress = None
    if trainer.max_epochs and trainer.num_training_batches:
        epoch_fraction = (batch_idx + 1) / trainer.num_training_batches
        progress = (
            trainer.current_epoch + epoch_fraction
        ) / trainer.max_epochs

    # 1-based counters for human-readable reporting.
    metrics: dict[str, Any] = {
        "epoch": trainer.current_epoch + 1,
        "batch": batch_idx + 1,
    }

    # Copy any scalar tensors Lightning has logged (e.g. the loss);
    # `.item()` converts them to plain Python numbers.
    if hasattr(pl_module, "trainer") and pl_module.trainer.callback_metrics:
        for name, logged in pl_module.trainer.callback_metrics.items():
            if hasattr(logged, "item"):
                metrics[name] = logged.item()

    self.callback_manager.on_learning_progress(
        self.learner,
        progress=progress,
        metrics=metrics,
    )

on_train_end(trainer, pl_module)

Called when training ends.

Source code in src/flowcean/torch/lightning_learner.py
93
94
95
96
97
98
def on_train_end(
    self,
    trainer: lightning.Trainer,  # noqa: ARG002
    pl_module: lightning.LightningModule,  # noqa: ARG002
) -> None:
    """Called when training ends."""
    # Intentionally a no-op: this bridge only forwards start/progress
    # events. NOTE(review): completion is presumably reported elsewhere
    # (e.g. by the learner once the trained model is available) — confirm
    # against the callback manager's API before adding a call here.

LightningLearner(module, num_workers=None, batch_size=32, max_epochs=100, accelerator='auto', callbacks=None)

Bases: SupervisedLearner

A learner that uses PyTorch Lightning.

Parameters:

Name Type Description Default
module LightningModule

The PyTorch Lightning module.

required
num_workers int | None

The number of workers to use for the DataLoader.

None
batch_size int

The batch size to use for training.

32
max_epochs int

The maximum number of epochs to train for.

100
accelerator str

The accelerator to use.

'auto'
callbacks list[LearnerCallback] | LearnerCallback | None

Optional callbacks for progress feedback. Use None for silent learning.

None

Initialize the learner.

Parameters:

Name Type Description Default
module LightningModule

The PyTorch Lightning module.

required
num_workers int | None

The number of workers to use for the DataLoader.

None
batch_size int

The batch size to use for training.

32
max_epochs int

The maximum number of epochs to train for.

100
accelerator str

The accelerator to use.

'auto'
callbacks list[LearnerCallback] | LearnerCallback | None

Optional callbacks for progress feedback. Use None for silent learning.

None
Source code in src/flowcean/torch/lightning_learner.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def __init__(
    self,
    module: lightning.LightningModule,
    num_workers: int | None = None,
    batch_size: int = 32,
    max_epochs: int = 100,
    accelerator: str = "auto",
    callbacks: list[LearnerCallback] | LearnerCallback | None = None,
) -> None:
    """Initialize the learner.

    Args:
        module: The PyTorch Lightning module.
        num_workers: The number of workers to use for the DataLoader.
            `None` selects `os.cpu_count()` (falling back to 0 when it
            is undetermined); an explicit `0` disables worker processes
            so data is loaded in the main process.
        batch_size: The batch size to use for training.
        max_epochs: The maximum number of epochs to train for.
        accelerator: The accelerator to use.
        callbacks: Optional callbacks for progress feedback. Use `None`
            for silent learning.
    """
    self.module = module
    # Bug fix: `num_workers or os.cpu_count() or 0` discarded an explicit
    # num_workers=0, which is a valid DataLoader setting (main-process
    # loading). Only fall back to the CPU count when the caller passed
    # None; `os.cpu_count()` itself may return None, hence the final 0.
    if num_workers is None:
        num_workers = os.cpu_count() or 0
    self.num_workers = num_workers
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.optimizer = None
    self.accelerator = accelerator
    self.callback_manager = create_callback_manager(callbacks)