Skip to content

train_test_split

TrainTestSplit(ratio, *, shuffle=False)

Split data into train and test sets.

Initialize the train-test splitter.

Parameters:

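| Name | Type | Description | Default |
|------|------|-------------|---------|
| `ratio` | `float` | The fraction of the data to put in the training set. Must lie in the interval [0, 1]; the remaining rows form the test set. | *required* |
| `shuffle` | `bool` | Whether to shuffle the data before splitting. | `False` |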
Source code in `src/flowcean/environments/train_test_split.py` (lines 23–39):
def __init__(
    self,
    ratio: float,
    *,
    shuffle: bool = False,
) -> None:
    """Initialize the train-test splitter.

    Args:
        ratio: The fraction of the data to put in the training set. Must
            lie in the closed interval ``[0, 1]``; the remaining rows form
            the test set.
        shuffle: Whether to shuffle the data before splitting.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]`` or is NaN.
    """
    # Written as a positive chained comparison so that NaN is rejected too:
    # every comparison involving NaN is False, so the negated form
    # ``ratio < 0 or ratio > 1`` would let NaN through. The error would then
    # only surface later, when the split pivot is converted to ``int``.
    if not 0 <= ratio <= 1:
        message = "ratio must be between 0 and 1"
        raise ValueError(message)
    self.ratio = ratio
    self.shuffle = shuffle

split(environment)

Split the data into train and test sets, returned as a `(train, test)` tuple of `Dataset` objects.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `environment` | `OfflineEnvironment` | The environment whose observed data is split. | *required* |
Source code in `src/flowcean/environments/train_test_split.py` (lines 41–59):
def split(
    self,
    environment: OfflineEnvironment,
) -> tuple[Dataset, Dataset]:
    """Partition the environment's observed data into train and test sets.

    The environment's observation is collected in full. It is then handed to
    ``_split`` with two requested part lengths: ``int(n_rows * self.ratio)``
    rows for training and the remainder for testing. The configured
    ``shuffle`` flag and the value of ``get_seed()`` are passed along.

    Args:
        environment: The environment to split.

    Returns:
        A ``(train, test)`` tuple of datasets, each wrapping the lazy form
        of the corresponding part returned by ``_split``.
    """
    logger.info("Splitting data into train and test sets")
    frame = environment.observe().collect(streaming=True)
    n_rows = len(frame)
    n_train = int(n_rows * self.ratio)
    n_test = n_rows - n_train
    # NOTE(review): assumes ``_split`` returns the parts in the order of
    # ``lengths``; confirm against its definition.
    parts = _split(
        frame,
        lengths=[n_train, n_test],
        shuffle=self.shuffle,
        seed=get_seed(),
    )
    train_frame = parts[0]
    test_frame = parts[1]
    return Dataset(train_frame.lazy()), Dataset(test_frame.lazy())