Skip to content

tool

Stop

Bases: Exception

Exception raised when the tool loop should be stopped.

start_prediction_loop(model, adapter, *, adapter_to_model_transforms=None, model_to_adapter_transforms=None)

Start a prediction loop with the given model and adapter.

Source code in src/flowcean/core/tool/predict.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def start_prediction_loop(
    model: Model,
    adapter: Adapter,
    *,
    adapter_to_model_transforms: Transform | None = None,
    model_to_adapter_transforms: Transform | None = None,
) -> None:
    """Run a continuous prediction loop over data supplied by an adapter.

    Each iteration pulls a sample from the adapter, transforms it into the
    model's input format, obtains a prediction, transforms the prediction
    back into the adapter's format, and sends it out. The loop runs until
    the adapter signals `Stop` or the user interrupts with Ctrl-C; the
    adapter is always stopped on exit.
    """
    # Fall back to pass-through transforms when none are supplied.
    to_model = adapter_to_model_transforms or Identity()
    to_adapter = model_to_adapter_transforms or Identity()

    adapter.start()
    try:
        while True:
            sample = adapter.get_data()
            model_input = to_model(sample)
            raw_prediction = model.predict(model_input)
            adapter.send_data(to_adapter(raw_prediction))
    except (Stop, KeyboardInterrupt):
        # Both a Stop signal and a user interrupt end the loop gracefully.
        pass
    finally:
        # Shut the adapter down even if the loop exits via an error.
        adapter.stop()

test_model(model, test_data, predicate, *, show_progress=False, stop_after=1)

Test a model with the given test data and predicate.

This function runs the model on the test data and checks if the predictions satisfy the given predicate. If any prediction does not satisfy the predicate, a TestFailed exception is raised. This exception contains the input data and prediction that failed the predicate and can be used as a counterexample. This method relies on the model's predict method to obtain a prediction. It does not utilize the model's type or internal structure to prove predicates.

Parameters:

Name Type Description Default
model Model

The model to test.

required
test_data IncrementalEnvironment

The test data to use for testing the model. This must only include input features passed to the model and not the targets.

required
predicate Predicate

The predicate used to check the model's predictions.

required
show_progress bool

Whether to show progress during testing. Defaults to False.

False
stop_after int

Number of tests that need to fail before stopping. Defaults to 1. If set to 0 or negative, all tests are run regardless of failures.

1

Raises:

Type Description
TestFailed

If the model's prediction does not satisfy the predicate.

Source code in src/flowcean/core/tool/test.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def test_model(
    model: Model,
    test_data: IncrementalEnvironment,
    predicate: Predicate,
    *,
    show_progress: bool = False,
    stop_after: int = 1,
) -> None:
    """Test a model with the given test data and predicate.

    This function runs the model on the test data and checks if the
    predictions satisfy the given predicate. If any prediction does not
    satisfy the predicate, a TestFailed exception is raised.
    This exception contains the input data and prediction that failed the
    predicate and can be used as a counterexample.
    This method relies on the model's predict method to obtain a prediction.
    It does not utilize the model's type or internal structure to prove
    predicates.

    Args:
        model: The model to test.
        test_data: The test data to use for testing the model. This must only
            include input features passed to the model and *not* the targets.
        predicate: The predicate used to check the model's predictions.
        show_progress: Whether to show progress during testing.
            Defaults to False.
        stop_after: Number of tests that need to fail before stopping. Defaults
            to 1. If set to 0 or negative, all tests are run regardless of
            failures.

    Raises:
        TestFailed: If the model's prediction does not satisfy the
            predicate.
    """
    number_of_failures = 0
    failure_data: list[Data] = []
    failure_prediction: list[Data] = []
    # Run the model on the test data
    for input_data in (
        tqdm.tqdm(
            test_data,
            "Testing Model",
            total=test_data.num_steps(),
        )
        if show_progress
        else test_data
    ):
        prediction = model.predict(input_data)

        # Handle dataframes and lazyframes separately
        # Those may contain multiple rows / samples and need to be
        # sliced to get the individual samples for testing
        if isinstance(prediction, pl.LazyFrame | pl.DataFrame) and isinstance(
            input_data,
            pl.LazyFrame | pl.DataFrame,
        ):
            input_data_collected = input_data.lazy().collect()
            prediction_collected = prediction.lazy().collect()

            test_inputs = [
                input_data_collected.slice(i, 1)
                for i in range(len(input_data_collected))
            ]
            predictions = [
                prediction_collected.slice(i, 1)
                for i in range(len(prediction_collected))
            ]

        else:
            test_inputs = [input_data]
            predictions = [prediction]

        # Check if each single-sample prediction satisfies the predicate.
        # Use a distinct name to avoid shadowing the batch-level
        # `prediction` variable above.
        for test_input, single_prediction in zip(
            test_inputs,
            predictions,
            strict=True,
        ):
            if not predicate(
                test_input,
                single_prediction,
            ):
                number_of_failures += 1
                failure_data.append(test_input)
                failure_prediction.append(single_prediction)

                if number_of_failures >= stop_after > 0:
                    break

        # Bug fix: the inner `break` only leaves the per-sample loop.
        # Also stop iterating over the remaining test data once the
        # failure limit has been reached, as documented for `stop_after`.
        if number_of_failures >= stop_after > 0:
            break

    # If we got any failures at this point, raise an exception
    if number_of_failures > 0:
        raise TestFailed(failure_data, failure_prediction)