adaboost_classifier

AdaBoost(*args, base_estimator=None, n_estimators=50, learning_rate=1.0, **kwargs)

Bases: SupervisedLearner

Wrapper class for sklearn's AdaBoostClassifier.

Reference: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.AdaBoostClassifier.html

Initialize the AdaBoost classifier learner.

Parameters:

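| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Positional arguments to pass to the `AdaBoostClassifier`. | `()` |
| `base_estimator` | `object` | The base estimator from which the boosted ensemble is built; it is forwarded to sklearn as `estimator`. If `None`, the base estimator is `DecisionTreeClassifier(max_depth=1)`. | `None` |
| `n_estimators` | `int` | The maximum number of estimators at which boosting is terminated. In case of a perfect fit, boosting stops early. | `50` |
| `learning_rate` | `float` | Weight applied to each classifier at each boosting iteration; values below 1.0 shrink each classifier's contribution. | `1.0` |
| `**kwargs` | `Any` | Keyword arguments to pass to the `AdaBoostClassifier`. Must not contain `random_state`, which the wrapper already sets from `get_seed()`. | `{}` |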
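To make the `learning_rate` parameter concrete: in the discrete SAMME variant of AdaBoost (the algorithm recent scikit-learn releases use), each fitted classifier receives the vote weight `learning_rate * (ln((1 - err) / err) + ln(K - 1))`, where `err` is its weighted training error and `K` the number of classes; the same weight also scales how strongly misclassified samples are up-weighted for the next round. The helper below is a hypothetical illustration of that formula, not a flowcean or scikit-learn function.

```python
import math


def samme_estimator_weight(
    error: float, n_classes: int, learning_rate: float = 1.0
) -> float:
    """Vote weight of one boosted classifier under SAMME (hypothetical helper)."""
    # err = 0 would make the log infinite (scikit-learn instead stops boosting on
    # a perfect fit), and err >= 1 - 1/K means no better than random guessing.
    if not 0.0 < error < 1.0 - 1.0 / n_classes:
        raise ValueError("error must lie in (0, 1 - 1/n_classes)")
    # For K = 2 the second term is ln(1) = 0.
    return learning_rate * (
        math.log((1.0 - error) / error) + math.log(n_classes - 1)
    )


full = samme_estimator_weight(0.25, n_classes=2, learning_rate=1.0)    # ln(3)
shrunk = samme_estimator_weight(0.25, n_classes=2, learning_rate=0.5)  # ln(3) / 2
```

Halving `learning_rate` halves every vote and slows the reweighting, which is why smaller learning rates are usually paired with a larger `n_estimators`.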
Source code in src/flowcean/sklearn/adaboost_classifier.py
```python
def __init__(
    self,
    *args: Any,
    base_estimator: object = None,
    n_estimators: int = 50,
    learning_rate: float = 1.0,
    **kwargs: Any,
) -> None:
    """Initialize the AdaBoost classifier learner.

    Args:
        *args: Positional arguments to pass to the AdaBoostClassifier.
        base_estimator: The base estimator from which the boosted ensemble
            is built. If None, then the base estimator is
            DecisionTreeClassifier(max_depth=1).
        n_estimators: The maximum number of estimators at which boosting is
            terminated. Defaults to 50.
        learning_rate: Learning rate shrinks the contribution of each
            classifier. Defaults to 1.0.
        **kwargs: Keyword arguments to pass to the AdaBoostClassifier.
    """
    self.classifier = AdaBoostClassifier(
        *args,
        estimator=base_estimator,
        n_estimators=n_estimators,
        learning_rate=learning_rate,
        random_state=get_seed(),
        **kwargs,
    )
```
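The listing shows that the wrapper always passes `random_state=get_seed()`, so the seed comes from flowcean's global seeding rather than from the caller. One consequence, sketched below with hypothetical stand-ins for both `get_seed()` and the sklearn constructor, is that `random_state` cannot also be supplied through `**kwargs`: Python rejects a keyword argument given twice before the called function even runs.

```python
from typing import Any


def get_seed() -> int:
    # Hypothetical stand-in for flowcean's get_seed(); always returns 42 here.
    return 42


def fake_adaboost(**params: Any) -> dict[str, Any]:
    # Hypothetical stand-in for sklearn's AdaBoostClassifier: it only
    # records the keyword arguments it receives.
    return params


def build(**kwargs: Any) -> dict[str, Any]:
    # Same call pattern as the wrapper's __init__: random_state is always
    # filled in from get_seed(), and any extra keywords are forwarded.
    return fake_adaboost(
        estimator=None,
        n_estimators=50,
        learning_rate=1.0,
        random_state=get_seed(),
        **kwargs,
    )


params = build()  # random_state comes from get_seed()

error_message = ""
try:
    build(random_state=0)  # the caller also tries to set the seed
except TypeError as exc:
    # Raised at the call site because random_state is given twice.
    error_message = str(exc)
```

Under this reading, reproducibility of the wrapped classifier is governed by whatever `get_seed()` returns, not by a per-learner argument.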