From b870da3a16ae2ac37d3b7b95c39034b5d9f0d1a2 Mon Sep 17 00:00:00 2001 From: meliurwen Date: Wed, 15 Feb 2023 09:34:22 +0100 Subject: [PATCH] First working version --- .gitignore | 2 ++ .gitmodules | 4 +++ Cargo.toml | 8 +++++ LICENSE | 21 +++++++++++++ deps/reCTBN | 1 + rust-toolchain.toml | 7 +++++ rustfmt.toml | 39 ++++++++++++++++++++++++ src/main.rs | 72 +++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 154 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 Cargo.toml create mode 100644 LICENSE create mode 160000 deps/reCTBN create mode 100644 rust-toolchain.toml create mode 100644 rustfmt.toml create mode 100644 src/main.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..08f0ba0 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "deps/reCTBN"] + path = deps/reCTBN + url = ../reCTBN.git + branch = dev diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b7462e4 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rectbn-benchmarks" +version = "0.1.0" +edition = "2021" + +[dependencies] +reCTBN = { path = "deps/reCTBN/reCTBN", package = "reCTBN", version="0.1.0" } +json = "0.12.*" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..84c256c --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Meliurwen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/deps/reCTBN b/deps/reCTBN new file mode 160000 index 0000000..e638a62 --- /dev/null +++ b/deps/reCTBN @@ -0,0 +1 @@ +Subproject commit e638a627bb1efb675d4242eff0bb543715b55ddc diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..367bc0b --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,7 @@ +# This file defines the Rust toolchain to use when a command is executed. +# See also https://rust-lang.github.io/rustup/overrides.html + +[toolchain] +channel = "stable" +components = [ "clippy", "rustfmt" ] +profile = "minimal" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..b6f1257 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,39 @@ +# This file defines the Rust style for automatic reformatting. +# See also https://rust-lang.github.io/rustfmt + +# NOTE: the unstable options will be uncommented when stabilized. + +# Version of the formatting rules to use. +#version = "One" + +# Number of spaces per tab. +tab_spaces = 4 + +max_width = 100 +#comment_width = 80 + +# Prevent carriage returns, admitted only \n. +newline_style = "Unix" + +# The "Default" setting has a heuristic which can split lines too aggresively. +#use_small_heuristics = "Max" + +# How imports should be grouped into `use` statements. +#imports_granularity = "Module" + +# How consecutive imports are grouped together. +#group_imports = "StdExternalCrate" + +# Error if unable to get all lines within max_width, except for comments and +# string literals. +#error_on_line_overflow = true + +# Error if unable to get comments or string literals within max_width, or they +# are left with trailing whitespaces. +#error_on_unformatted = true + +# Files to ignore like third party code which is formatted upstream. +# Ignoring tests is a temporary measure due some issues regarding rank-3 tensors +ignore = [ + "tests/" +] diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..5003247 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,72 @@ +#![allow(non_snake_case)] + +use std::collections::BTreeSet; + +use reCTBN::parameter_learning::MLE; +use reCTBN::params::DiscreteStatesContinousTimeParams; +use reCTBN::params::Params::DiscreteStatesContinousTime; +use reCTBN::process::ctbn::CtbnNetwork; +use reCTBN::process::NetworkProcess; +use reCTBN::structure_learning::constraint_based_algorithm::CTPC; +use reCTBN::structure_learning::hypothesis_test::{ChiSquare, F}; +use reCTBN::structure_learning::StructureLearningAlgorithm; +use reCTBN::tools::trajectory_generator; +use reCTBN::tools::Dataset; +use reCTBN::tools::RandomGraphGenerator; +use reCTBN::tools::RandomParametersGenerator; +use reCTBN::tools::UniformGraphGenerator; +use reCTBN::tools::UniformParametersGenerator; + +fn uniform_parameters_generator_right_densities_ctmp() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + let nodes_cardinality = 20; + let domain_cardinality = 3; + for node in 0..nodes_cardinality { + // Create the domain for a discrete node + let mut domain = BTreeSet::new(); + for dvalue in 0..domain_cardinality { + domain.insert(dvalue.to_string()); + } + // Create the parameters for a discrete node using the domain + let param = DiscreteStatesContinousTimeParams::new(node.to_string(), domain); + //Create the node using the parameters + let node = DiscreteStatesContinousTime(param); + // Add the node to the network + net.add_node(node).unwrap(); + } + + // Initialize the Graph Generator using the one with an + // uniform distribution + let mut structure_generator = UniformGraphGenerator::new(1.0 / 3.0, Some(7641630759785120)); + + // Generate the graph directly on the network + structure_generator.generate_graph(&mut net); + + // Initialize the parameters generator with uniform distributin + let mut cim_generator = UniformParametersGenerator::new(3.0..7.0, Some(7641630759785120)); + + // Generate CIMs with uniformly distributed parameters. + cim_generator.generate_parameters(&mut net); + + let dataset = trajectory_generator(&net, 300, 200.0, Some(30230423)); + + return (net, dataset); +} + +fn structure_learning_CTPC(net: CtbnNetwork, dataset: &Dataset) { + // Initialize the hypothesis tests to pass to the CTPC with their + // respective significance level `alpha` + let f = F::new(1e-6); + let chi_sq = ChiSquare::new(1e-4); + // Use the bayesian approach to learn the parameters + let parameter_learning = MLE {}; + //Initialize CTPC + let ctpc = CTPC::new(parameter_learning, f, chi_sq); + // Learn the structure of the network from the generated trajectory + ctpc.fit_transform(net, dataset); +} + +fn main() { + let (net, dataset) = uniform_parameters_generator_right_densities_ctmp(); + structure_learning_CTPC(net, &dataset); +}