Merge branch '35-refactor-constructor' into 'dev'

pull/44/head
Meliurwen 3 years ago
commit 74039dac94
Signed by: meliurwen
GPG Key ID: 818A8B35E9F1CE10
  1. 12
      src/ctbn.rs
  2. 2
      src/node.rs
  3. 2
      src/params.rs
  4. 2
      src/structure_learning/score_based_algorithm.rs
  5. 6
      src/structure_learning/score_function.rs
  6. 8
      src/tools.rs
  7. 12
      tests/ctbn.rs
  8. 8
      tests/parameter_learning.rs
  9. 52
      tests/structure_learning.rs
  10. 10
      tests/tools.rs
  11. 4
      tests/utils.rs

@ -29,19 +29,19 @@ use std::collections::BTreeSet;
/// domain.insert(String::from("B"));
///
/// //Create the parameters for a discrete node using the domain
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
/// let param = params::DiscreteStatesContinousTimeParams::new(domain);
///
/// //Create the node using the parameters
/// let X1 = node::Node::init(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
/// let X1 = node::Node::new(params::Params::DiscreteStatesContinousTime(param),String::from("X1"));
///
/// let mut domain = BTreeSet::new();
/// domain.insert(String::from("A"));
/// domain.insert(String::from("B"));
/// let param = params::DiscreteStatesContinousTimeParams::init(domain);
/// let X2 = node::Node::init(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
/// let param = params::DiscreteStatesContinousTimeParams::new(domain);
/// let X2 = node::Node::new(params::Params::DiscreteStatesContinousTime(param), String::from("X2"));
///
/// //Initialize a ctbn
/// let mut net = CtbnNetwork::init();
/// let mut net = CtbnNetwork::new();
///
/// //Add nodes
/// let X1 = net.add_node(X1).unwrap();
@ -61,7 +61,7 @@ pub struct CtbnNetwork {
impl CtbnNetwork {
pub fn init() -> CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new()

@ -7,7 +7,7 @@ pub struct Node {
}
impl Node {
pub fn init(params: Params, label: String) -> Node {
pub fn new(params: Params, label: String) -> Node {
Node{
params: params,
label:label

@ -77,7 +77,7 @@ pub struct DiscreteStatesContinousTimeParams {
}
impl DiscreteStatesContinousTimeParams {
pub fn init(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
pub fn new(domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
DiscreteStatesContinousTimeParams {
domain,
cim: Option::None,

@ -10,7 +10,7 @@ pub struct HillClimbing<S: ScoreFunction> {
}
impl<S: ScoreFunction> HillClimbing<S> {
pub fn init(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
pub fn new(score_function: S, max_parent_set: Option<usize>) -> HillClimbing<S> {
HillClimbing {
score_function,
max_parent_set,

@ -24,7 +24,7 @@ pub struct LogLikelihood {
}
impl LogLikelihood {
pub fn init(alpha: usize, tau: f64) -> LogLikelihood {
pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
//Tau must be >=0.0
if tau < 0.0 {
@ -106,9 +106,9 @@ pub struct BIC {
}
impl BIC {
pub fn init(alpha: usize, tau: f64) -> BIC {
pub fn new(alpha: usize, tau: f64) -> BIC {
BIC {
ll: LogLikelihood::init(alpha, tau)
ll: LogLikelihood::new(alpha, tau)
}
}
}

@ -12,7 +12,7 @@ pub struct Trajectory {
}
impl Trajectory {
pub fn init(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
pub fn new(time: Array1<f64>, events: Array2<usize>) -> Trajectory {
//Events and time are two part of the same trajectory. For this reason they must have the
//same number of sample.
if time.shape()[0] != events.shape()[0] {
@ -35,7 +35,7 @@ pub struct Dataset {
}
impl Dataset {
pub fn init(trajectories: Vec<Trajectory>) -> Dataset {
pub fn new(trajectories: Vec<Trajectory>) -> Dataset {
//All the trajectories in the same dataset must represent the same process. For this reason
//each trajectory must represent the same number of variables.
@ -178,7 +178,7 @@ pub fn trajectory_generator<T: network::Network>(
time.push(t_end.clone());
//Add the sampled trajectory to trajectories.
trajectories.push(Trajectory::init(
trajectories.push(Trajectory::new(
Array::from_vec(time),
Array2::from_shape_vec(
(events.len(), current_state.len()),
@ -188,5 +188,5 @@ pub fn trajectory_generator<T: network::Network>(
));
}
//Return a dataset object with the sampled trajectories.
Dataset::init(trajectories)
Dataset::new(trajectories)
}

@ -8,20 +8,20 @@ use rustyCTBN::ctbn::*;
#[test]
fn define_simpe_ctbn() {
let _ = CtbnNetwork::init();
let _ = CtbnNetwork::new();
assert!(true);
}
#[test]
fn add_node_to_ctbn() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
assert_eq!(String::from("n1"), net.get_node(n1).label);
}
#[test]
fn add_edge_to_ctbn() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
net.add_edge(n1, n2);
@ -31,7 +31,7 @@ fn add_edge_to_ctbn() {
#[test]
fn children_and_parents() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
net.add_edge(n1, n2);
@ -44,7 +44,7 @@ fn children_and_parents() {
#[test]
fn compute_index_ctbn() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();
@ -76,7 +76,7 @@ fn compute_index_ctbn() {
#[test]
fn compute_index_from_custom_parent_set() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let _n1 = net.add_node(generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let _n2 = net.add_node(generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
let _n3 = net.add_node(generate_discrete_time_continous_node(String::from("n3"),2)).unwrap();

@ -16,7 +16,7 @@ extern crate approx;
fn learn_binary_cim<T: ParameterLearning> (pl: T) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap();
@ -66,7 +66,7 @@ fn learn_binary_cim_BA() {
}
fn learn_ternary_cim<T: ParameterLearning> (pl: T) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
@ -121,7 +121,7 @@ fn learn_ternary_cim_BA() {
}
fn learn_ternary_cim_no_parents<T: ParameterLearning> (pl: T) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
@ -175,7 +175,7 @@ fn learn_ternary_cim_no_parents_BA() {
fn learn_mixed_discrete_cim<T: ParameterLearning> (pl: T) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();

@ -18,18 +18,18 @@ extern crate approx;
#[test]
fn simple_score_test() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap();
let trj = Trajectory::init(
let trj = Trajectory::new(
arr1(&[0.0,0.1,0.3]),
arr2(&[[0],[1],[1]]));
let dataset = Dataset::init(vec![trj]);
let dataset = Dataset::new(vec![trj]);
let ll = LogLikelihood::init(1, 1.0);
let ll = LogLikelihood::new(1, 1.0);
assert_abs_diff_eq!(0.04257, ll.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
@ -38,17 +38,17 @@ fn simple_score_test() {
#[test]
fn simple_bic() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),2))
.unwrap();
let trj = Trajectory::init(
let trj = Trajectory::new(
arr1(&[0.0,0.1,0.3]),
arr2(&[[0],[1],[1]]));
let dataset = Dataset::init(vec![trj]);
let bic = BIC::init(1, 1.0);
let dataset = Dataset::new(vec![trj]);
let bic = BIC::new(1, 1.0);
assert_abs_diff_eq!(-0.65058, bic.call(&net, n1, &BTreeSet::new(), &dataset), epsilon=1e-3);
@ -57,7 +57,7 @@ fn simple_bic() {
fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm> (sl: T) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
@ -86,7 +86,7 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
let data = trajectory_generator(&net, 100, 20.0, Some(6347747169756259),);
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let _n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
@ -97,13 +97,13 @@ fn check_compatibility_between_dataset_and_network<T: StructureLearningAlgorithm
#[test]
#[should_panic]
pub fn check_compatibility_between_dataset_and_network_hill_climbing() {
let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll, None);
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
check_compatibility_between_dataset_and_network(hl);
}
fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
@ -140,21 +140,21 @@ fn learn_ternary_net_2_nodes<T: StructureLearningAlgorithm> (sl: T) {
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_ll() {
let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll, None);
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_ternary_net_2_nodes(hl);
}
#[test]
pub fn learn_ternary_net_2_nodes_hill_climbing_bic() {
let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic, None);
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_ternary_net_2_nodes(hl);
}
fn get_mixed_discrete_net_3_nodes_with_data() -> (CtbnNetwork, Dataset) {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net
.add_node(generate_discrete_time_continous_node(String::from("n1"),3))
.unwrap();
@ -222,15 +222,15 @@ fn learn_mixed_discrete_net_3_nodes<T: StructureLearningAlgorithm> (sl: T) {
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll() {
let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll, None);
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, None);
learn_mixed_discrete_net_3_nodes(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic() {
let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic, None);
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, None);
learn_mixed_discrete_net_3_nodes(hl);
}
@ -247,14 +247,14 @@ fn learn_mixed_discrete_net_3_nodes_1_parent_constraint<T: StructureLearningAlgo
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_ll_1_parent_constraint() {
let ll = LogLikelihood::init(1, 1.0);
let hl = HillClimbing::init(ll, Some(1));
let ll = LogLikelihood::new(1, 1.0);
let hl = HillClimbing::new(ll, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}
#[test]
pub fn learn_mixed_discrete_net_3_nodes_hill_climbing_bic_1_parent_constraint() {
let bic = BIC::init(1, 1.0);
let hl = HillClimbing::init(bic, Some(1));
let bic = BIC::new(1, 1.0);
let hl = HillClimbing::new(bic, Some(1));
learn_mixed_discrete_net_3_nodes_1_parent_constraint(hl);
}

@ -16,7 +16,7 @@ mod utils;
#[test]
fn run_sampling() {
let mut net = CtbnNetwork::init();
let mut net = CtbnNetwork::new();
let n1 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n1"),2)).unwrap();
let n2 = net.add_node(utils::generate_discrete_time_continous_node(String::from("n2"),2)).unwrap();
net.add_edge(n1, n2);
@ -48,7 +48,7 @@ fn run_sampling() {
fn trajectory_wrong_shape() {
let time = arr1(&[0.0, 0.2]);
let events = arr2(&[[0,3]]);
Trajectory::init(time, events);
Trajectory::new(time, events);
}
@ -57,11 +57,11 @@ fn run_sampling() {
fn dataset_wrong_shape() {
let time = arr1(&[0.0, 0.2]);
let events = arr2(&[[0,3], [1,2]]);
let t1 = Trajectory::init(time, events);
let t1 = Trajectory::new(time, events);
let time = arr1(&[0.0, 0.2]);
let events = arr2(&[[0,3,3], [1,2,3]]);
let t2 = Trajectory::init(time, events);
Dataset::init(vec![t1, t2]);
let t2 = Trajectory::new(time, events);
Dataset::new(vec![t1, t2]);
}

@ -3,13 +3,13 @@ use rustyCTBN::node;
use std::collections::BTreeSet;
pub fn generate_discrete_time_continous_node(name: String, cardinality: usize) -> node::Node {
node::Node::init(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name)
node::Node::new(params::Params::DiscreteStatesContinousTime(generate_discrete_time_continous_param(cardinality)), name)
}
pub fn generate_discrete_time_continous_param(cardinality: usize) -> params::DiscreteStatesContinousTimeParams{
let domain: BTreeSet<String> = (0..cardinality).map(|x| x.to_string()).collect();
params::DiscreteStatesContinousTimeParams::init(domain)
params::DiscreteStatesContinousTimeParams::new(domain)
}

Loading…
Cancel
Save