
pytorch

PyTorchModel(module, output_names, batch_size=32, num_workers=1)

Bases: Model

PyTorch model wrapper.

Initialize the model.

Parameters:

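| Name         | Type      | Description                                       | Default  |
|--------------|-----------|---------------------------------------------------|----------|
| module       | Module    | The PyTorch module.                               | required |
| output_names | list[str] | The names of the output columns.                  | required |
| batch_size   | int       | The batch size to use for predictions.            | 32       |
| num_workers  | int       | The number of workers to use for the DataLoader.  | 1        |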
Source code in src/flowcean/models/pytorch.py
def __init__(
    self,
    module: Module,
    output_names: list[str],
    batch_size: int = 32,
    num_workers: int = 1,
) -> None:
    """Initialize the model.

    Args:
        module: The PyTorch module.
        output_names: The names of the output columns.
        batch_size: The batch size to use for predictions.
        num_workers: The number of workers to use for the DataLoader.
    """
    self.module = module
    self.output_names = output_names
    self.batch_size = batch_size
    self.num_workers = num_workers
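The constructor only stores its arguments, so the roles of batch_size and output_names are easiest to see in a prediction loop. The sketch below is a hypothetical, dependency-free illustration, not flowcean's actual implementation. It assumes the wrapped module maps a batch of input rows to a batch of output rows and that each output position corresponds, in order, to one entry of output_names. The helper predict_in_batches and the toy_module stand-in are invented for this example. num_workers is omitted because, per the parameter description, it only controls the DataLoader's worker processes and does not change the results.

```python
from typing import Callable, Sequence

# Hypothetical types for this sketch only: a "module" here is any callable
# mapping a batch of input rows to a batch of output rows (a stand-in for
# a real torch.nn.Module forward pass).
Row = list[float]
BatchFn = Callable[[list[Row]], list[Row]]


def predict_in_batches(
    module: BatchFn,
    inputs: Sequence[Row],
    output_names: list[str],
    batch_size: int = 32,
) -> dict[str, list[float]]:
    """Hypothetical helper: run `module` over `inputs` in chunks of
    `batch_size` rows and collect the results into named columns."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    columns: dict[str, list[float]] = {name: [] for name in output_names}
    # Step through the inputs one batch at a time; the last batch may be smaller.
    for start in range(0, len(inputs), batch_size):
        batch = list(inputs[start : start + batch_size])
        for row in module(batch):
            # Each output value is labeled by its position in output_names.
            if len(row) != len(output_names):
                raise ValueError("module output width != len(output_names)")
            for name, value in zip(output_names, row):
                columns[name].append(value)
    return columns


# Usage: a toy module that records the size of every batch it receives.
batch_sizes: list[int] = []


def toy_module(batch: list[Row]) -> list[Row]:
    batch_sizes.append(len(batch))
    return [[sum(row), 2 * row[0]] for row in batch]


result = predict_in_batches(
    toy_module, [[1.0, 2.0]] * 5, ["total", "doubled"], batch_size=2
)
print(batch_sizes)  # [2, 2, 1]
print(result)
```

With five input rows and batch_size=2, the module is called three times (batches of 2, 2 and 1 rows). Every returned row contributes one value to each named output column.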