Added F1 computation and updated reCTBN submodule

master
Meliurwen 2 years ago
parent 09be34d13a
commit e7b260750b
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 2
      deps/reCTBN
  2. 56
      src/main.rs

2
deps/reCTBN vendored

@ -1 +1 @@
Subproject commit e638a627bb1efb675d4242eff0bb543715b55ddc Subproject commit 0eb427e5cfda8b06cc4a1dc24725316b01237f75

@ -4,6 +4,7 @@ use csv;
use json; use json;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::fs; use std::fs;
use std::process::exit;
use std::time::Instant; use std::time::Instant;
use reCTBN::parameter_learning::MLE; use reCTBN::parameter_learning::MLE;
@ -85,7 +86,7 @@ where
Ok(()) Ok(())
} }
fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) { fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) -> CtbnNetwork {
// Initialize the hypothesis tests to pass to the CTPC with their // Initialize the hypothesis tests to pass to the CTPC with their
// respective significance level `alpha` // respective significance level `alpha`
let f = F::new(1e-6); let f = F::new(1e-6);
@ -95,17 +96,17 @@ fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) {
//Initialize CTPC //Initialize CTPC
let ctpc = CTPC::new(parameter_learning, f, chi_sq); let ctpc = CTPC::new(parameter_learning, f, chi_sq);
// Learn the structure of the network from the generated trajectory // Learn the structure of the network from the generated trajectory
ctpc.fit_transform(net, dataset); ctpc.fit_transform(net, dataset)
} }
fn main() { fn main() {
let file_path = "./networks-settings.json"; let file_path = "./networks-settings.json";
let csv_path = "./results.csv";
eprintln!("Opening file {}...", file_path); eprintln!("Opening file {}...", file_path);
let file_content = let file_content =
fs::read_to_string(file_path).expect("File not found! Check the README for instructions!"); fs::read_to_string(file_path).expect("File not found! Check the README for instructions!");
let parsed_json = json::parse(&file_content).unwrap(); let parsed_json = json::parse(&file_content).unwrap();
let csv_path = "./results.csv";
write_csv_record( write_csv_record(
csv_path, csv_path,
&[ &[
@ -113,6 +114,7 @@ fn main() {
"domain_cardinality", "domain_cardinality",
"density", "density",
"duration", "duration",
"f1_score",
], ],
) )
.unwrap(); .unwrap();
@ -136,13 +138,29 @@ fn main() {
parsed_json[idx]["cg_seed"].as_u64(), parsed_json[idx]["cg_seed"].as_u64(),
parsed_json[idx]["tg_seed"].as_u64(), parsed_json[idx]["tg_seed"].as_u64(),
); );
let strucure_original = net.get_adj_matrix().unwrap().clone();
let relevant_elements = strucure_original.iter().filter(|&elem| *elem == 1).count() as f64;
if relevant_elements == 0 as f64 {
eprintln!(
"[{}/{}] The network gnerated is empty! Change seeds for this network!",
idx + 1,
benchmarks_cardinality
);
exit(1)
}
eprintln!(
"[{}/{}] Original strucure:\n{:?}",
idx + 1,
benchmarks_cardinality,
strucure_original
);
eprintln!( eprintln!(
"[{}/{}] Structure learning CTPC...", "[{}/{}] Structure learning CTPC...",
idx + 1, idx + 1,
benchmarks_cardinality benchmarks_cardinality
); );
let start = Instant::now(); let start = Instant::now();
structure_learning_CTPC(net, &dataset); let net = structure_learning_CTPC(net, &dataset);
let duration = start.elapsed(); let duration = start.elapsed();
eprintln!( eprintln!(
"[{}/{}] Strucure learned with CTPC in {:?}.", "[{}/{}] Strucure learned with CTPC in {:?}.",
@ -150,6 +168,35 @@ fn main() {
benchmarks_cardinality, benchmarks_cardinality,
duration duration
); );
let structure_generated = net.get_adj_matrix().unwrap();
let retrieved_elements = structure_generated
.iter()
.filter(|&elem| *elem == 1)
.count() as f64;
eprintln!(
"[{}/{}] Discovered strucure:\n{}",
idx + 1,
benchmarks_cardinality,
structure_generated
);
let sum_structure = strucure_original + structure_generated;
eprintln!(
"[{}/{}] strucure_original + structure_generated value:\n{}",
idx + 1,
benchmarks_cardinality,
sum_structure
);
let true_positives = sum_structure.iter().filter(|&elem| *elem == 2).count() as f64;
// https://en.wikipedia.org/wiki/F-score
let precision = true_positives / retrieved_elements;
let recall = true_positives / relevant_elements;
let f1_score = 2.0 * (precision * recall) / (precision + recall);
eprintln!(
"[{}/{}] f1_score: {}",
idx + 1,
benchmarks_cardinality,
f1_score
);
write_csv_record( write_csv_record(
csv_path, csv_path,
&[ &[
@ -163,6 +210,7 @@ fn main() {
.to_string(), .to_string(),
parsed_json[idx]["density"].as_f64().unwrap().to_string(), parsed_json[idx]["density"].as_f64().unwrap().to_string(),
duration.as_millis().to_string(), duration.as_millis().to_string(),
f1_score.to_string(),
], ],
) )
.unwrap(); .unwrap();

Loading…
Cancel
Save