Skip to content

regression_tree

RegressionTree(*, dot_graph_export_path=None, criterion='squared_error', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, ccp_alpha=0.0, monotonic_cst=None)

Bases: SupervisedLearner

Wrapper class for sklearn's DecisionTreeRegressor.

Reference: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html

Initialize the regression tree learner.

Reference: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html

Source code in src/flowcean/sklearn/regression_tree.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def __init__(
    self,
    *,
    dot_graph_export_path: None | str = None,
    criterion: str = "squared_error",
    splitter: str = "best",
    max_depth: int | None = None,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1,
    min_weight_fraction_leaf: float = 0.0,
    max_features: float | str | None = None,
    random_state: int | None = None,
    max_leaf_nodes: int | None = None,
    min_impurity_decrease: float = 0.0,
    ccp_alpha: float = 0.0,
    monotonic_cst: NDArray | None = None,
) -> None:
    """Initialize the regression tree learner.

    Every argument except ``dot_graph_export_path`` is forwarded unchanged
    to ``sklearn.tree.DecisionTreeRegressor``. The only exception is
    ``random_state``, which falls back to ``get_seed()`` when it is ``None``.
    See the scikit-learn reference for the exact meaning of each option.

    Reference: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html

    Args:
        dot_graph_export_path: Optional path stored as
            ``self.dot_graph_export_path``. It is presumably used elsewhere
            in the class to export the fitted tree as a Graphviz dot file;
            this constructor only stores it.
        criterion: Split quality criterion passed to the regressor.
        splitter: Split strategy passed to the regressor.
        max_depth: Maximum tree depth; ``None`` means unlimited.
        min_samples_split: Minimum samples required to split a node.
        min_samples_leaf: Minimum samples required at a leaf node.
        min_weight_fraction_leaf: Minimum weighted fraction of samples
            required at a leaf node.
        max_features: Number of features considered per split. May be an
            int, a float fraction, a string strategy accepted by
            scikit-learn, or ``None``.
        random_state: Seed for the regressor. If ``None``, the seed
            returned by ``get_seed()`` is used. Any explicit integer is
            honored, including ``0``.
        max_leaf_nodes: Maximum number of leaf nodes; ``None`` means
            unlimited.
        min_impurity_decrease: Minimum impurity decrease required to split.
        ccp_alpha: Complexity parameter for minimal cost-complexity pruning.
        monotonic_cst: Optional per-feature monotonicity constraints.
    """
    # Compare against None explicitly: `random_state or get_seed()` would
    # discard a valid seed of 0 because 0 is falsy.
    seed = random_state if random_state is not None else get_seed()
    self.regressor = DecisionTreeRegressor(
        criterion=criterion,
        splitter=splitter,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        min_weight_fraction_leaf=min_weight_fraction_leaf,
        max_features=max_features,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
        random_state=seed,
        ccp_alpha=ccp_alpha,
        monotonic_cst=monotonic_cst,
    )
    self.dot_graph_export_path = dot_graph_export_path