From aa630bf9e9e2682489688fc0e8628ad7ac796549 Mon Sep 17 00:00:00 2001 From: AlessandroBregoli Date: Tue, 4 Oct 2022 12:04:20 +0200 Subject: [PATCH] Generate trajectories from python --- reCTBN/src/tools.rs | 1 + reCTBNpy/src/lib.rs | 3 ++ reCTBNpy/src/pyctbn.rs | 55 ++++++++++++++++++++++++------- reCTBNpy/src/pyparams.rs | 71 +++++++++++++++++++++++++++++++++++----- reCTBNpy/src/pytools.rs | 50 ++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 reCTBNpy/src/pytools.rs diff --git a/reCTBN/src/tools.rs b/reCTBN/src/tools.rs index 70bbf76..b05ab5e 100644 --- a/reCTBN/src/tools.rs +++ b/reCTBN/src/tools.rs @@ -3,6 +3,7 @@ use ndarray::prelude::*; use crate::sampling::{ForwardSampler, Sampler}; use crate::{network, params}; +#[derive(Clone)] pub struct Trajectory { time: Array1, events: Array2, diff --git a/reCTBNpy/src/lib.rs b/reCTBNpy/src/lib.rs index b86c0e6..81d8051 100644 --- a/reCTBNpy/src/lib.rs +++ b/reCTBNpy/src/lib.rs @@ -1,6 +1,7 @@ use pyo3::prelude::*; pub mod pyctbn; pub mod pyparams; +pub mod pytools; @@ -13,6 +14,8 @@ fn reCTBNpy(py: Python, m: &PyModule) -> PyResult<()> { let params_module = PyModule::new(py, "params")?; params_module.add_class::()?; + params_module.add_class::()?; + params_module.add_class::()?; m.add_submodule(params_module)?; Ok(()) } diff --git a/reCTBNpy/src/pyctbn.rs b/reCTBNpy/src/pyctbn.rs index d7835ff..d432ce2 100644 --- a/reCTBNpy/src/pyctbn.rs +++ b/reCTBNpy/src/pyctbn.rs @@ -1,35 +1,68 @@ use std::collections::BTreeSet; +use crate::{pyparams, pytools}; use pyo3::prelude::*; -use reCTBN::{ctbn, network::Network}; +use reCTBN::{ctbn, network::Network, params, tools, params::Params}; #[pyclass] -pub struct PyCtbnNetwork { - ctbn_network: ctbn::CtbnNetwork, -} +pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork); #[pymethods] impl PyCtbnNetwork { #[new] pub fn new() -> Self { - PyCtbnNetwork { - ctbn_network: ctbn::CtbnNetwork::new(), - } + PyCtbnNetwork(ctbn::CtbnNetwork::new()) + } + + pub fn add_node(&mut self, n: pyparams::PyParams) { + self.0.add_node(n.0); } pub fn get_number_of_nodes(&self) -> usize { - self.ctbn_network.get_number_of_nodes() + self.0.get_number_of_nodes() } pub fn add_edge(&mut self, parent: usize, child: usize) { - self.ctbn_network.add_edge(parent, child); + self.0.add_edge(parent, child); + } + + pub fn get_node_indices(&self) -> BTreeSet { + self.0.get_node_indices().collect() } pub fn get_parent_set(&self, node: usize) -> BTreeSet { - self.ctbn_network.get_parent_set(node) + self.0.get_parent_set(node) } pub fn get_children_set(&self, node: usize) -> BTreeSet { - self.ctbn_network.get_children_set(node) + self.0.get_children_set(node) + } + + pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) { + match &n.0 { + Params::DiscreteStatesContinousTime(new_p) => { + if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){ + p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap(); + + } + else { + panic!("Node type mismatch") + } + } + } + } + + pub fn trajectory_generator( + &self, + n_trajectories: u64, + t_end: f64, + seed: Option, + ) -> pytools::PyDataset { + pytools::PyDataset(tools::trajectory_generator( + &self.0, + n_trajectories, + t_end, + seed, + )) } } diff --git a/reCTBNpy/src/pyparams.rs b/reCTBNpy/src/pyparams.rs index 4181ea8..378d60d 100644 --- a/reCTBNpy/src/pyparams.rs +++ b/reCTBNpy/src/pyparams.rs @@ -18,32 +18,85 @@ impl From for PyParamsError { } #[pyclass] -pub struct PyDiscreteStateContinousTime { - param: params::DiscreteStatesContinousTimeParams, +pub struct PyStateType(pub params::StateType); + +#[pyclass] +#[derive(Clone)] +pub struct PyParams(pub params::Params); + +#[pymethods] +impl PyParams { + #[staticmethod] + pub fn new_discrete_state_continous_time(p: PyDiscreteStateContinousTime) -> Self{ + PyParams(params::Params::DiscreteStatesContinousTime(p.0)) + } + + pub fn get_reserved_space_as_parent(&self) -> usize { + self.0.get_reserved_space_as_parent() + } + + pub fn get_label(&self) -> String { + self.0.get_label().to_string() + } } +/// DiscreteStatesContinousTime. +/// This represents the parameters of a classical discrete node for ctbn and it's composed by the +/// following elements: +/// - **domain**: an ordered and exhaustive set of possible states +/// - **cim**: Conditional Intensity Matrix +/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter +/// learning task and are composed by: +/// - **transitions**: number of transitions from one state to another given a specific +/// realization of the parent set +/// - **residence_time**: permanence time in each possible states given a specific +/// realization of the parent set +#[derive(Clone)] +#[pyclass] +pub struct PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams); + + #[pymethods] impl PyDiscreteStateContinousTime { #[new] pub fn new(label: String, domain: BTreeSet) -> Self { - PyDiscreteStateContinousTime { - param: params::DiscreteStatesContinousTimeParams::new(label, domain), - } + PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams::new(label, domain)) } pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3> { - match self.param.get_cim() { + match self.0.get_cim() { Some(x) => Some(x.to_pyarray(py)), None => None, } } pub fn set_cim<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3) -> Result<(), PyParamsError> { - self.param.set_cim(cim.as_array().to_owned())?; + self.0.set_cim(cim.as_array().to_owned())?; Ok(()) } - pub fn get_label(&self) -> String { - self.param.get_label().to_string() + + pub fn get_transitions<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3> { + match self.0.get_transitions() { + Some(x) => Some(x.to_pyarray(py)), + None => None, + } } + + pub fn set_transitions<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3){ + self.0.set_transitions(cim.as_array().to_owned()); + } + + + pub fn get_residence_time<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray2> { + match self.0.get_residence_time() { + Some(x) => Some(x.to_pyarray(py)), + None => None, + } + } + + pub fn set_residence_time<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray2) { + self.0.set_residence_time(cim.as_array().to_owned()); + } + } diff --git a/reCTBNpy/src/pytools.rs b/reCTBNpy/src/pytools.rs new file mode 100644 index 0000000..2c4fd32 --- /dev/null +++ b/reCTBNpy/src/pytools.rs @@ -0,0 +1,50 @@ +use numpy::{self, ToPyArray}; +use pyo3::{exceptions::PyValueError, prelude::*}; +use reCTBN::{tools, network}; + +#[pyclass] +#[derive(Clone)] +pub struct PyTrajectory(pub tools::Trajectory); + +#[pymethods] +impl PyTrajectory { + #[new] + pub fn new( + time: numpy::PyReadonlyArray1, + events: numpy::PyReadonlyArray2, + ) -> PyTrajectory { + PyTrajectory(tools::Trajectory::new( + time.as_array().to_owned(), + events.as_array().to_owned(), + )) + } + + pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1 { + self.0.get_time().to_pyarray(py) + } + + pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2 { + self.0.get_events().to_pyarray(py) + } +} + +#[pyclass] +pub struct PyDataset(pub tools::Dataset); + +#[pymethods] +impl PyDataset { + #[new] + pub fn new(trajectories: Vec) -> PyDataset { + PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect())) + } + + pub fn get_number_of_trajectories(&self) -> usize { + self.0.get_trajectories().len() + } + + pub fn get_trajectory(&self, idx: usize) -> PyTrajectory { + PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone()) + } + +} +