Added F1 computation and updated reCTBN submodule

master
Meliurwen 2 years ago
parent 09be34d13a
commit e7b260750b
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 2
      deps/reCTBN
  2. 56
      src/main.rs

2
deps/reCTBN vendored

@ -1 +1 @@
Subproject commit e638a627bb1efb675d4242eff0bb543715b55ddc
Subproject commit 0eb427e5cfda8b06cc4a1dc24725316b01237f75

@ -4,6 +4,7 @@ use csv;
use json;
use std::collections::BTreeSet;
use std::fs;
use std::process::exit;
use std::time::Instant;
use reCTBN::parameter_learning::MLE;
@ -85,7 +86,7 @@ where
Ok(())
}
fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) {
fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) -> CtbnNetwork {
// Initialize the hypothesis tests to pass to the CTPC with their
// respective significance level `alpha`
let f = F::new(1e-6);
@ -95,17 +96,17 @@ fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) {
//Initialize CTPC
let ctpc = CTPC::new(parameter_learning, f, chi_sq);
// Learn the structure of the network from the generated trajectory
ctpc.fit_transform(net, dataset);
ctpc.fit_transform(net, dataset)
}
fn main() {
let file_path = "./networks-settings.json";
let csv_path = "./results.csv";
eprintln!("Opening file {}...", file_path);
let file_content =
fs::read_to_string(file_path).expect("File not found! Check the README for instructions!");
let parsed_json = json::parse(&file_content).unwrap();
let csv_path = "./results.csv";
write_csv_record(
csv_path,
&[
@ -113,6 +114,7 @@ fn main() {
"domain_cardinality",
"density",
"duration",
"f1_score",
],
)
.unwrap();
@ -136,13 +138,29 @@ fn main() {
parsed_json[idx]["cg_seed"].as_u64(),
parsed_json[idx]["tg_seed"].as_u64(),
);
let strucure_original = net.get_adj_matrix().unwrap().clone();
let relevant_elements = strucure_original.iter().filter(|&elem| *elem == 1).count() as f64;
if relevant_elements == 0 as f64 {
eprintln!(
"[{}/{}] The network gnerated is empty! Change seeds for this network!",
idx + 1,
benchmarks_cardinality
);
exit(1)
}
eprintln!(
"[{}/{}] Original strucure:\n{:?}",
idx + 1,
benchmarks_cardinality,
strucure_original
);
eprintln!(
"[{}/{}] Structure learning CTPC...",
idx + 1,
benchmarks_cardinality
);
let start = Instant::now();
structure_learning_CTPC(net, &dataset);
let net = structure_learning_CTPC(net, &dataset);
let duration = start.elapsed();
eprintln!(
"[{}/{}] Strucure learned with CTPC in {:?}.",
@ -150,6 +168,35 @@ fn main() {
benchmarks_cardinality,
duration
);
let structure_generated = net.get_adj_matrix().unwrap();
let retrieved_elements = structure_generated
.iter()
.filter(|&elem| *elem == 1)
.count() as f64;
eprintln!(
"[{}/{}] Discovered strucure:\n{}",
idx + 1,
benchmarks_cardinality,
structure_generated
);
let sum_structure = strucure_original + structure_generated;
eprintln!(
"[{}/{}] strucure_original + structure_generated value:\n{}",
idx + 1,
benchmarks_cardinality,
sum_structure
);
let true_positives = sum_structure.iter().filter(|&elem| *elem == 2).count() as f64;
// https://en.wikipedia.org/wiki/F-score
let precision = true_positives / retrieved_elements;
let recall = true_positives / relevant_elements;
let f1_score = 2.0 * (precision * recall) / (precision + recall);
eprintln!(
"[{}/{}] f1_score: {}",
idx + 1,
benchmarks_cardinality,
f1_score
);
write_csv_record(
csv_path,
&[
@ -163,6 +210,7 @@ fn main() {
.to_string(),
parsed_json[idx]["density"].as_f64().unwrap().to_string(),
duration.as_millis().to_string(),
f1_score.to_string(),
],
)
.unwrap();

Loading…
Cancel
Save