Generate trajectories from python

62-feature-rectbnpy
AlessandroBregoli 2 years ago
parent 3dae67a80c
commit aa630bf9e9
  1. 1
      reCTBN/src/tools.rs
  2. 3
      reCTBNpy/src/lib.rs
  3. 55
      reCTBNpy/src/pyctbn.rs
  4. 71
      reCTBNpy/src/pyparams.rs
  5. 50
      reCTBNpy/src/pytools.rs

@ -3,6 +3,7 @@ use ndarray::prelude::*;
use crate::sampling::{ForwardSampler, Sampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{network, params}; use crate::{network, params};
#[derive(Clone)]
pub struct Trajectory { pub struct Trajectory {
time: Array1<f64>, time: Array1<f64>,
events: Array2<usize>, events: Array2<usize>,

@ -1,6 +1,7 @@
use pyo3::prelude::*; use pyo3::prelude::*;
pub mod pyctbn; pub mod pyctbn;
pub mod pyparams; pub mod pyparams;
pub mod pytools;
@ -13,6 +14,8 @@ fn reCTBNpy(py: Python, m: &PyModule) -> PyResult<()> {
let params_module = PyModule::new(py, "params")?; let params_module = PyModule::new(py, "params")?;
params_module.add_class::<pyparams::PyDiscreteStateContinousTime>()?; params_module.add_class::<pyparams::PyDiscreteStateContinousTime>()?;
params_module.add_class::<pyparams::PyStateType>()?;
params_module.add_class::<pyparams::PyParams>()?;
m.add_submodule(params_module)?; m.add_submodule(params_module)?;
Ok(()) Ok(())
} }

@ -1,35 +1,68 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use crate::{pyparams, pytools};
use pyo3::prelude::*; use pyo3::prelude::*;
use reCTBN::{ctbn, network::Network}; use reCTBN::{ctbn, network::Network, params, tools, params::Params};
#[pyclass] #[pyclass]
pub struct PyCtbnNetwork { pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork);
ctbn_network: ctbn::CtbnNetwork,
}
#[pymethods] #[pymethods]
impl PyCtbnNetwork { impl PyCtbnNetwork {
#[new] #[new]
pub fn new() -> Self { pub fn new() -> Self {
PyCtbnNetwork { PyCtbnNetwork(ctbn::CtbnNetwork::new())
ctbn_network: ctbn::CtbnNetwork::new(), }
}
pub fn add_node(&mut self, n: pyparams::PyParams) {
self.0.add_node(n.0);
} }
pub fn get_number_of_nodes(&self) -> usize { pub fn get_number_of_nodes(&self) -> usize {
self.ctbn_network.get_number_of_nodes() self.0.get_number_of_nodes()
} }
pub fn add_edge(&mut self, parent: usize, child: usize) { pub fn add_edge(&mut self, parent: usize, child: usize) {
self.ctbn_network.add_edge(parent, child); self.0.add_edge(parent, child);
}
pub fn get_node_indices(&self) -> BTreeSet<usize> {
self.0.get_node_indices().collect()
} }
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.ctbn_network.get_parent_set(node) self.0.get_parent_set(node)
} }
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.ctbn_network.get_children_set(node) self.0.get_children_set(node)
}
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) {
match &n.0 {
Params::DiscreteStatesContinousTime(new_p) => {
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap();
}
else {
panic!("Node type mismatch")
}
}
}
}
pub fn trajectory_generator(
&self,
n_trajectories: u64,
t_end: f64,
seed: Option<u64>,
) -> pytools::PyDataset {
pytools::PyDataset(tools::trajectory_generator(
&self.0,
n_trajectories,
t_end,
seed,
))
} }
} }

@ -18,32 +18,85 @@ impl From<params::ParamsError> for PyParamsError {
} }
#[pyclass] #[pyclass]
pub struct PyDiscreteStateContinousTime { pub struct PyStateType(pub params::StateType);
param: params::DiscreteStatesContinousTimeParams,
#[pyclass]
#[derive(Clone)]
pub struct PyParams(pub params::Params);
#[pymethods]
impl PyParams {
#[staticmethod]
pub fn new_discrete_state_continous_time(p: PyDiscreteStateContinousTime) -> Self{
PyParams(params::Params::DiscreteStatesContinousTime(p.0))
}
pub fn get_reserved_space_as_parent(&self) -> usize {
self.0.get_reserved_space_as_parent()
}
pub fn get_label(&self) -> String {
self.0.get_label().to_string()
}
} }
/// DiscreteStatesContinousTime.
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
/// following elements:
/// - **domain**: an ordered and exhaustive set of possible states
/// - **cim**: Conditional Intensity Matrix
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
/// learning task and are composed by:
/// - **transitions**: number of transitions from one state to another given a specific
/// realization of the parent set
/// - **residence_time**: permanence time in each possible states given a specific
/// realization of the parent set
#[derive(Clone)]
#[pyclass]
pub struct PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams);
#[pymethods] #[pymethods]
impl PyDiscreteStateContinousTime { impl PyDiscreteStateContinousTime {
#[new] #[new]
pub fn new(label: String, domain: BTreeSet<String>) -> Self { pub fn new(label: String, domain: BTreeSet<String>) -> Self {
PyDiscreteStateContinousTime { PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams::new(label, domain))
param: params::DiscreteStatesContinousTimeParams::new(label, domain),
}
} }
pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<f64>> { pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<f64>> {
match self.param.get_cim() { match self.0.get_cim() {
Some(x) => Some(x.to_pyarray(py)), Some(x) => Some(x.to_pyarray(py)),
None => None, None => None,
} }
} }
pub fn set_cim<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<f64>) -> Result<(), PyParamsError> { pub fn set_cim<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<f64>) -> Result<(), PyParamsError> {
self.param.set_cim(cim.as_array().to_owned())?; self.0.set_cim(cim.as_array().to_owned())?;
Ok(()) Ok(())
} }
pub fn get_label(&self) -> String {
self.param.get_label().to_string() pub fn get_transitions<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<usize>> {
match self.0.get_transitions() {
Some(x) => Some(x.to_pyarray(py)),
None => None,
}
} }
pub fn set_transitions<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<usize>){
self.0.set_transitions(cim.as_array().to_owned());
}
pub fn get_residence_time<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray2<f64>> {
match self.0.get_residence_time() {
Some(x) => Some(x.to_pyarray(py)),
None => None,
}
}
pub fn set_residence_time<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray2<f64>) {
self.0.set_residence_time(cim.as_array().to_owned());
}
} }

@ -0,0 +1,50 @@
use numpy::{self, ToPyArray};
use pyo3::{exceptions::PyValueError, prelude::*};
use reCTBN::{tools, network};
#[pyclass]
#[derive(Clone)]
pub struct PyTrajectory(pub tools::Trajectory);
#[pymethods]
impl PyTrajectory {
#[new]
pub fn new(
time: numpy::PyReadonlyArray1<f64>,
events: numpy::PyReadonlyArray2<usize>,
) -> PyTrajectory {
PyTrajectory(tools::Trajectory::new(
time.as_array().to_owned(),
events.as_array().to_owned(),
))
}
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> {
self.0.get_time().to_pyarray(py)
}
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> {
self.0.get_events().to_pyarray(py)
}
}
#[pyclass]
pub struct PyDataset(pub tools::Dataset);
#[pymethods]
impl PyDataset {
#[new]
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset {
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect()))
}
pub fn get_number_of_trajectories(&self) -> usize {
self.0.get_trajectories().len()
}
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory {
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone())
}
}
Loading…
Cancel
Save