From e7b260750be206584c0221346fcf1496b409972d Mon Sep 17 00:00:00 2001 From: meliurwen Date: Thu, 16 Feb 2023 16:58:11 +0100 Subject: [PATCH] Added F1 computation and updated reCTBN submodule --- deps/reCTBN | 2 +- src/main.rs | 56 +++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/deps/reCTBN b/deps/reCTBN index e638a62..0eb427e 160000 --- a/deps/reCTBN +++ b/deps/reCTBN @@ -1 +1 @@ -Subproject commit e638a627bb1efb675d4242eff0bb543715b55ddc +Subproject commit 0eb427e5cfda8b06cc4a1dc24725316b01237f75 diff --git a/src/main.rs b/src/main.rs index d59d2d4..fee9bd8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use csv; use json; use std::collections::BTreeSet; use std::fs; +use std::process::exit; use std::time::Instant; use reCTBN::parameter_learning::MLE; @@ -85,7 +86,7 @@ where Ok(()) } -fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) { +fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) -> CtbnNetwork { // Initialize the hypothesis tests to pass to the CTPC with their // respective significance level `alpha` let f = F::new(1e-6); @@ -95,17 +96,17 @@ fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) { //Initialize CTPC let ctpc = CTPC::new(parameter_learning, f, chi_sq); // Learn the structure of the network from the generated trajectory - ctpc.fit_transform(net, dataset); + ctpc.fit_transform(net, dataset) } fn main() { let file_path = "./networks-settings.json"; - let csv_path = "./results.csv"; eprintln!("Opening file {}...", file_path); let file_content = fs::read_to_string(file_path).expect("File not found! Check the README for instructions!"); let parsed_json = json::parse(&file_content).unwrap(); + let csv_path = "./results.csv"; write_csv_record( csv_path, &[ @@ -113,6 +114,7 @@ fn main() { "domain_cardinality", "density", "duration", + "f1_score", ], ) .unwrap(); @@ -136,13 +138,29 @@ fn main() { parsed_json[idx]["cg_seed"].as_u64(), parsed_json[idx]["tg_seed"].as_u64(), ); + let strucure_original = net.get_adj_matrix().unwrap().clone(); + let relevant_elements = strucure_original.iter().filter(|&elem| *elem == 1).count() as f64; + if relevant_elements == 0 as f64 { + eprintln!( + "[{}/{}] The network gnerated is empty! Change seeds for this network!", + idx + 1, + benchmarks_cardinality + ); + exit(1) + } + eprintln!( + "[{}/{}] Original strucure:\n{:?}", + idx + 1, + benchmarks_cardinality, + strucure_original + ); eprintln!( "[{}/{}] Structure learning CTPC...", idx + 1, benchmarks_cardinality ); let start = Instant::now(); - structure_learning_CTPC(net, &dataset); + let net = structure_learning_CTPC(net, &dataset); let duration = start.elapsed(); eprintln!( "[{}/{}] Strucure learned with CTPC in {:?}.", @@ -150,6 +168,35 @@ fn main() { benchmarks_cardinality, duration ); + let structure_generated = net.get_adj_matrix().unwrap(); + let retrieved_elements = structure_generated + .iter() + .filter(|&elem| *elem == 1) + .count() as f64; + eprintln!( + "[{}/{}] Discovered strucure:\n{}", + idx + 1, + benchmarks_cardinality, + structure_generated + ); + let sum_structure = strucure_original + structure_generated; + eprintln!( + "[{}/{}] strucure_original + structure_generated value:\n{}", + idx + 1, + benchmarks_cardinality, + sum_structure + ); + let true_positives = sum_structure.iter().filter(|&elem| *elem == 2).count() as f64; + // https://en.wikipedia.org/wiki/F-score + let precision = true_positives / retrieved_elements; + let recall = true_positives / relevant_elements; + let f1_score = 2.0 * (precision * recall) / (precision + recall); + eprintln!( + "[{}/{}] f1_score: {}", + idx + 1, + benchmarks_cardinality, + f1_score + ); write_csv_record( csv_path, &[ @@ -163,6 +210,7 @@ fn main() { .to_string(), parsed_json[idx]["density"].as_f64().unwrap().to_string(), duration.as_millis().to_string(), + f1_score.to_string(), ], ) .unwrap();