From 8ce2d6ee049606f9fc02dce75c34590c8f909253 Mon Sep 17 00:00:00 2001 From: Luca Moretti Date: Mon, 12 Oct 2020 17:41:07 +0200 Subject: [PATCH] Added precision and recall --- .../structure_score_based_estimator.py | 32 ++++++++++++++----- .../test_structure_score_based_estimator.py | 2 +- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/main_package/classes/estimators/structure_score_based_estimator.py b/main_package/classes/estimators/structure_score_based_estimator.py index 1816f0b..472df93 100644 --- a/main_package/classes/estimators/structure_score_based_estimator.py +++ b/main_package/classes/estimators/structure_score_based_estimator.py @@ -62,7 +62,7 @@ class StructureScoreBasedEstimator(se.StructureEstimator): """ 'Save the true edges structure in tuples' true_edges = copy.deepcopy(self.sample_path.structure.edges) - true_edges = list(map(tuple, true_edges)) + true_edges = set(map(tuple, true_edges)) 'Remove all the edges from the structure' self.sample_path.structure.clean_structure_edges() @@ -91,23 +91,39 @@ class StructureScoreBasedEstimator(se.StructureEstimator): print('-------------------------') 'TODO: Pensare a un modo migliore -- set difference sembra non funzionare ' + n_missing_edges = 0 n_added_fake_edges = 0 - for estimate_edge in list_edges: - if not estimate_edge in true_edges: - n_added_fake_edges += 1 + set_list_edges = set(list_edges) + + n_added_fake_edges = len(set_list_edges.difference(true_edges)) + + n_missing_edges = len(true_edges.difference(set_list_edges)) + + n_true_positive = len(true_edges) - n_missing_edges + + precision = n_true_positive / (n_true_positive + n_added_fake_edges) + + recall = n_true_positive / (n_true_positive + n_missing_edges) + + + # for estimate_edge in list_edges: + # if not estimate_edge in true_edges: + # n_added_fake_edges += 1 - for true_edge in true_edges: - if not true_edge in list_edges: - n_missing_edges += 1 - print(true_edge) + # for true_edge in true_edges: + # if not true_edge in list_edges: + # n_missing_edges += 1 + # print(true_edge) print(f"n archi reali non trovati: {n_missing_edges}") print(f"n archi non reali aggiunti: {n_added_fake_edges}") print(true_edges) print(list_edges) + print(f"precision: {precision} ") + print(f"recall: {recall} ") def estimate_parents(self,node_id:str, max_parents:int = None, iterations_number:int= 40, patience:int = 10 ): diff --git a/main_package/tests/estimators/test_structure_score_based_estimator.py b/main_package/tests/estimators/test_structure_score_based_estimator.py index b4acb1b..4ad74ec 100644 --- a/main_package/tests/estimators/test_structure_score_based_estimator.py +++ b/main_package/tests/estimators/test_structure_score_based_estimator.py @@ -32,7 +32,7 @@ class TestStructureScoreBasedEstimator(unittest.TestCase): def test_esecuzione(self): se1 = se.StructureScoreBasedEstimator(self.s1) se1.estimate_structure( - max_parents = 6, + max_parents = 3, iterations_number = 80, patience = None )