commit b3dd9133afba625aec6afe39049f66e15cdf11a6 Author: meliurwen Date: Tue Feb 7 11:49:03 2023 +0100 First working version diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c640ca5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +.vscode diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d790c02 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "reCTBN"] + path = deps/reCTBN + url = ../reCTBN.git + branch = dev diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c7b74a7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "exporter" +version = "0.1.0" +edition = "2021" + +[dependencies] +reCTBN = { path = "deps/reCTBN/reCTBN", package = "reCTBN", version="0.1.0" } +json = "0.12.*" diff --git a/deps/reCTBN b/deps/reCTBN new file mode 160000 index 0000000..49c2c55 --- /dev/null +++ b/deps/reCTBN @@ -0,0 +1 @@ +Subproject commit 49c2c55f613996574e4f901cc53a3d06ebac6311 diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..367bc0b --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,7 @@ +# This file defines the Rust toolchain to use when a command is executed. +# See also https://rust-lang.github.io/rustup/overrides.html + +[toolchain] +channel = "stable" +components = [ "clippy", "rustfmt" ] +profile = "minimal" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..b6f1257 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,39 @@ +# This file defines the Rust style for automatic reformatting. +# See also https://rust-lang.github.io/rustfmt + +# NOTE: the unstable options will be uncommented when stabilized. + +# Version of the formatting rules to use. +#version = "One" + +# Number of spaces per tab. +tab_spaces = 4 + +max_width = 100 +#comment_width = 80 + +# Prevent carriage returns, admitted only \n. +newline_style = "Unix" + +# The "Default" setting has a heuristic which can split lines too aggresively. +#use_small_heuristics = "Max" + +# How imports should be grouped into `use` statements. +#imports_granularity = "Module" + +# How consecutive imports are grouped together. +#group_imports = "StdExternalCrate" + +# Error if unable to get all lines within max_width, except for comments and +# string literals. +#error_on_line_overflow = true + +# Error if unable to get comments or string literals within max_width, or they +# are left with trailing whitespaces. +#error_on_unformatted = true + +# Files to ignore like third party code which is formatted upstream. +# Ignoring tests is a temporary measure due some issues regarding rank-3 tensors +ignore = [ + "tests/" +] diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..45917fb --- /dev/null +++ b/src/main.rs @@ -0,0 +1,96 @@ +use reCTBN; + +use json::*; +use reCTBN::params; +use reCTBN::params::Params::DiscreteStatesContinousTime; +use reCTBN::params::ParamsTrait; +use reCTBN::process::ctbn::CtbnNetwork; +use reCTBN::process::NetworkProcess; +use reCTBN::tools::trajectory_generator; +use reCTBN::tools::Dataset; +use reCTBN::tools::RandomGraphGenerator; +use reCTBN::tools::RandomParametersGenerator; +use reCTBN::tools::UniformGraphGenerator; +use reCTBN::tools::UniformParametersGenerator; +use std::collections::BTreeSet; + +fn uniform_parameters_generator_right_densities_ctmp() -> (CtbnNetwork, Dataset) { + let mut net = CtbnNetwork::new(); + let nodes_cardinality = 3; + let domain_cardinality = 2; + for node in 0..nodes_cardinality { + // Create the domain for a discrete node + let mut domain = BTreeSet::new(); + for dvalue in 0..domain_cardinality { + domain.insert(dvalue.to_string()); + } + // Create the parameters for a discrete node using the domain + let param = params::DiscreteStatesContinousTimeParams::new(node.to_string(), domain); + //Create the node using the parameters + let node = DiscreteStatesContinousTime(param); + // Add the node to the network + net.add_node(node).unwrap(); + } + + // Initialize the Graph Generator using the one with an + // uniform distribution + let mut structure_generator = UniformGraphGenerator::new(1.0 / 3.0, Some(7641630759785120)); + + // Generate the graph directly on the network + structure_generator.generate_graph(&mut net); + + // Initialize the parameters generator with uniform distributin + let mut cim_generator = UniformParametersGenerator::new(3.0..7.0, Some(7641630759785120)); + + // Generate CIMs with uniformly distributed parameters. + cim_generator.generate_parameters(&mut net); + + let dataset = trajectory_generator(&net, 3, 20.0, Some(30230423)); + + return (net, dataset); +} + +fn main() { + println!("Hello, world!"); + let mut data = json::JsonValue::new_array(); + data.push(json::JsonValue::new_object()).unwrap(); + data[0]["dyn.str"] = json::JsonValue::new_array(); + data[0]["variables"] = json::JsonValue::new_array(); + data[0]["dyn.cims"] = object! {}; + data[0]["samples"] = json::JsonValue::new_array(); + let (net, dataset) = uniform_parameters_generator_right_densities_ctmp(); + + for node_idx in net.get_node_indices() { + let mut variable = json::JsonValue::new_object(); + variable["Name"] = json::JsonValue::String(net.get_node(node_idx).get_label().to_string()); + variable["Value"] = + json::JsonValue::Number(net.get_node(node_idx).get_reserved_space_as_parent().into()); + data[0]["variables"].push(variable).unwrap(); + for parent_idx in net.get_parent_set(node_idx) { + let mut edge = json::JsonValue::new_object(); + edge["From"] = + json::JsonValue::String(net.get_node(parent_idx).get_label().to_string()); + edge["To"] = json::JsonValue::String(net.get_node(node_idx).get_label().to_string()); + data[0]["dyn.str"].push(edge).unwrap(); + } + } + let nodes: Vec = net + .get_node_indices() + .into_iter() + .map(|x| net.get_node(x).get_label().to_string()) + .collect(); + for sample in dataset.get_trajectories() { + let mut trajectory = json::JsonValue::new_array(); + for event_idx in 0..sample.get_time().shape()[0] { + let mut event = json::JsonValue::new_object(); + event["Time"] = json::JsonValue::Number(sample.get_time()[event_idx].into()); + for (node_idx, node_label) in nodes.iter().enumerate() { + event[node_label] = + json::JsonValue::Number(sample.get_events()[[event_idx, node_idx]].into()); + } + trajectory.push(event).unwrap(); + } + data[0]["samples"].push(trajectory).unwrap(); + } + println!("{}", data.dump()); +}