diff --git a/tests/test_learning.py b/tests/test_learning.py index 8a21d6462..932070820 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -165,6 +165,15 @@ def test_decision_tree_learner(): assert dTL([7.5, 4, 6, 2]) == "virginica" +def test_information_content(): + assert information_content([]) == 0 + assert information_content([4]) == 0 + assert information_content([5, 4, 0, 2, 5, 0]) > 1.9 + assert information_content([5, 4, 0, 2, 5, 0]) < 2 + assert information_content([1.5, 2.5]) > 0.9 + assert information_content([1.5, 2.5]) < 1.0 + + def test_random_forest(): iris = DataSet(name="iris") rF = RandomForest(iris)