|
|
@ -4,6 +4,7 @@ use csv; |
|
|
|
use json; |
|
|
|
use json; |
|
|
|
use std::collections::BTreeSet; |
|
|
|
use std::collections::BTreeSet; |
|
|
|
use std::fs; |
|
|
|
use std::fs; |
|
|
|
|
|
|
|
use std::process::exit; |
|
|
|
use std::time::Instant; |
|
|
|
use std::time::Instant; |
|
|
|
|
|
|
|
|
|
|
|
use reCTBN::parameter_learning::MLE; |
|
|
|
use reCTBN::parameter_learning::MLE; |
|
|
@ -85,7 +86,7 @@ where |
|
|
|
Ok(()) |
|
|
|
Ok(()) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) { |
|
|
|
fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) -> CtbnNetwork { |
|
|
|
// Initialize the hypothesis tests to pass to the CTPC with their
|
|
|
|
// Initialize the hypothesis tests to pass to the CTPC with their
|
|
|
|
// respective significance level `alpha`
|
|
|
|
// respective significance level `alpha`
|
|
|
|
let f = F::new(1e-6); |
|
|
|
let f = F::new(1e-6); |
|
|
@ -95,17 +96,17 @@ fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) { |
|
|
|
//Initialize CTPC
|
|
|
|
//Initialize CTPC
|
|
|
|
let ctpc = CTPC::new(parameter_learning, f, chi_sq); |
|
|
|
let ctpc = CTPC::new(parameter_learning, f, chi_sq); |
|
|
|
// Learn the structure of the network from the generated trajectory
|
|
|
|
// Learn the structure of the network from the generated trajectory
|
|
|
|
ctpc.fit_transform(net, dataset); |
|
|
|
ctpc.fit_transform(net, dataset) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fn main() { |
|
|
|
fn main() { |
|
|
|
let file_path = "./networks-settings.json"; |
|
|
|
let file_path = "./networks-settings.json"; |
|
|
|
let csv_path = "./results.csv"; |
|
|
|
|
|
|
|
eprintln!("Opening file {}...", file_path); |
|
|
|
eprintln!("Opening file {}...", file_path); |
|
|
|
let file_content = |
|
|
|
let file_content = |
|
|
|
fs::read_to_string(file_path).expect("File not found! Check the README for instructions!"); |
|
|
|
fs::read_to_string(file_path).expect("File not found! Check the README for instructions!"); |
|
|
|
let parsed_json = json::parse(&file_content).unwrap(); |
|
|
|
let parsed_json = json::parse(&file_content).unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let csv_path = "./results.csv"; |
|
|
|
write_csv_record( |
|
|
|
write_csv_record( |
|
|
|
csv_path, |
|
|
|
csv_path, |
|
|
|
&[ |
|
|
|
&[ |
|
|
@ -113,6 +114,7 @@ fn main() { |
|
|
|
"domain_cardinality", |
|
|
|
"domain_cardinality", |
|
|
|
"density", |
|
|
|
"density", |
|
|
|
"duration", |
|
|
|
"duration", |
|
|
|
|
|
|
|
"f1_score", |
|
|
|
], |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
.unwrap(); |
|
|
|
.unwrap(); |
|
|
@ -136,13 +138,29 @@ fn main() { |
|
|
|
parsed_json[idx]["cg_seed"].as_u64(), |
|
|
|
parsed_json[idx]["cg_seed"].as_u64(), |
|
|
|
parsed_json[idx]["tg_seed"].as_u64(), |
|
|
|
parsed_json[idx]["tg_seed"].as_u64(), |
|
|
|
); |
|
|
|
); |
|
|
|
|
|
|
|
let strucure_original = net.get_adj_matrix().unwrap().clone(); |
|
|
|
|
|
|
|
let relevant_elements = strucure_original.iter().filter(|&elem| *elem == 1).count() as f64; |
|
|
|
|
|
|
|
if relevant_elements == 0 as f64 { |
|
|
|
|
|
|
|
eprintln!( |
|
|
|
|
|
|
|
"[{}/{}] The network gnerated is empty! Change seeds for this network!", |
|
|
|
|
|
|
|
idx + 1, |
|
|
|
|
|
|
|
benchmarks_cardinality |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
exit(1) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
eprintln!( |
|
|
|
|
|
|
|
"[{}/{}] Original strucure:\n{:?}", |
|
|
|
|
|
|
|
idx + 1, |
|
|
|
|
|
|
|
benchmarks_cardinality, |
|
|
|
|
|
|
|
strucure_original |
|
|
|
|
|
|
|
); |
|
|
|
eprintln!( |
|
|
|
eprintln!( |
|
|
|
"[{}/{}] Structure learning CTPC...", |
|
|
|
"[{}/{}] Structure learning CTPC...", |
|
|
|
idx + 1, |
|
|
|
idx + 1, |
|
|
|
benchmarks_cardinality |
|
|
|
benchmarks_cardinality |
|
|
|
); |
|
|
|
); |
|
|
|
let start = Instant::now(); |
|
|
|
let start = Instant::now(); |
|
|
|
structure_learning_CTPC(net, &dataset); |
|
|
|
let net = structure_learning_CTPC(net, &dataset); |
|
|
|
let duration = start.elapsed(); |
|
|
|
let duration = start.elapsed(); |
|
|
|
eprintln!( |
|
|
|
eprintln!( |
|
|
|
"[{}/{}] Strucure learned with CTPC in {:?}.", |
|
|
|
"[{}/{}] Strucure learned with CTPC in {:?}.", |
|
|
@ -150,6 +168,35 @@ fn main() { |
|
|
|
benchmarks_cardinality, |
|
|
|
benchmarks_cardinality, |
|
|
|
duration |
|
|
|
duration |
|
|
|
); |
|
|
|
); |
|
|
|
|
|
|
|
let structure_generated = net.get_adj_matrix().unwrap(); |
|
|
|
|
|
|
|
let retrieved_elements = structure_generated |
|
|
|
|
|
|
|
.iter() |
|
|
|
|
|
|
|
.filter(|&elem| *elem == 1) |
|
|
|
|
|
|
|
.count() as f64; |
|
|
|
|
|
|
|
eprintln!( |
|
|
|
|
|
|
|
"[{}/{}] Discovered strucure:\n{}", |
|
|
|
|
|
|
|
idx + 1, |
|
|
|
|
|
|
|
benchmarks_cardinality, |
|
|
|
|
|
|
|
structure_generated |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
let sum_structure = strucure_original + structure_generated; |
|
|
|
|
|
|
|
eprintln!( |
|
|
|
|
|
|
|
"[{}/{}] strucure_original + structure_generated value:\n{}", |
|
|
|
|
|
|
|
idx + 1, |
|
|
|
|
|
|
|
benchmarks_cardinality, |
|
|
|
|
|
|
|
sum_structure |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
let true_positives = sum_structure.iter().filter(|&elem| *elem == 2).count() as f64; |
|
|
|
|
|
|
|
// https://en.wikipedia.org/wiki/F-score
|
|
|
|
|
|
|
|
let precision = true_positives / retrieved_elements; |
|
|
|
|
|
|
|
let recall = true_positives / relevant_elements; |
|
|
|
|
|
|
|
let f1_score = 2.0 * (precision * recall) / (precision + recall); |
|
|
|
|
|
|
|
eprintln!( |
|
|
|
|
|
|
|
"[{}/{}] f1_score: {}", |
|
|
|
|
|
|
|
idx + 1, |
|
|
|
|
|
|
|
benchmarks_cardinality, |
|
|
|
|
|
|
|
f1_score |
|
|
|
|
|
|
|
); |
|
|
|
write_csv_record( |
|
|
|
write_csv_record( |
|
|
|
csv_path, |
|
|
|
csv_path, |
|
|
|
&[ |
|
|
|
&[ |
|
|
@ -163,6 +210,7 @@ fn main() { |
|
|
|
.to_string(), |
|
|
|
.to_string(), |
|
|
|
parsed_json[idx]["density"].as_f64().unwrap().to_string(), |
|
|
|
parsed_json[idx]["density"].as_f64().unwrap().to_string(), |
|
|
|
duration.as_millis().to_string(), |
|
|
|
duration.as_millis().to_string(), |
|
|
|
|
|
|
|
f1_score.to_string(), |
|
|
|
], |
|
|
|
], |
|
|
|
) |
|
|
|
) |
|
|
|
.unwrap(); |
|
|
|
.unwrap(); |
|
|
|