pytorch
PyTorchModel(module, output_names, batch_size=32, num_workers=1)
Bases: Model
PyTorch model wrapper.
Initialize the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | Module | The PyTorch module. | required |
output_names | list[str] | The names of the output columns. | required |
batch_size | int | The batch size to use for predictions. | 32 |
num_workers | int | The number of workers to use for the DataLoader. | 1 |
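The following hypothetical sketch uses plain PyTorch to illustrate what `batch_size` and `num_workers` control. It assumes the wrapper passes inputs through a `torch.utils.data.DataLoader` (as the `num_workers` description suggests) and maps each output column of the module to one entry of `output_names`. The helper `predict_in_batches` and its dict-of-lists return type are illustrative assumptions, not flowcean's API or its actual implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def predict_in_batches(
    module: nn.Module,
    inputs: torch.Tensor,
    output_names: list[str],
    batch_size: int = 32,
    num_workers: int = 1,
) -> dict[str, list[float]]:
    """Hypothetical helper: run `module` over `inputs` in mini-batches.

    Assumes the module returns a tensor of shape (batch, len(output_names)).
    """
    # batch_size and num_workers are forwarded to the DataLoader unchanged.
    loader = DataLoader(
        TensorDataset(inputs),
        batch_size=batch_size,
        num_workers=num_workers,
    )
    module.eval()  # switch layers such as dropout to inference behaviour
    outputs = []
    with torch.no_grad():  # no gradients are needed for prediction
        for (batch,) in loader:
            outputs.append(module(batch))
    stacked = torch.cat(outputs, dim=0)
    if stacked.ndim != 2 or stacked.shape[1] != len(output_names):
        raise ValueError("module output does not match output_names")
    # One named column per output feature.
    return {name: stacked[:, i].tolist() for i, name in enumerate(output_names)}
```

Passing `num_workers=0` loads batches in the main process, which avoids spawning worker processes in small scripts or tests.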
Source code in src/flowcean/models/pytorch.py