parent
3dae67a80c
commit
aa630bf9e9
@ -1,35 +1,68 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use crate::{pyparams, pytools}; |
||||
use pyo3::prelude::*; |
||||
use reCTBN::{ctbn, network::Network}; |
||||
use reCTBN::{ctbn, network::Network, params, tools, params::Params}; |
||||
|
||||
#[pyclass] |
||||
pub struct PyCtbnNetwork { |
||||
ctbn_network: ctbn::CtbnNetwork, |
||||
} |
||||
pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork); |
||||
|
||||
#[pymethods] |
||||
impl PyCtbnNetwork { |
||||
#[new] |
||||
pub fn new() -> Self { |
||||
PyCtbnNetwork { |
||||
ctbn_network: ctbn::CtbnNetwork::new(), |
||||
} |
||||
PyCtbnNetwork(ctbn::CtbnNetwork::new()) |
||||
} |
||||
|
||||
pub fn add_node(&mut self, n: pyparams::PyParams) { |
||||
self.0.add_node(n.0); |
||||
} |
||||
|
||||
pub fn get_number_of_nodes(&self) -> usize { |
||||
self.ctbn_network.get_number_of_nodes() |
||||
self.0.get_number_of_nodes() |
||||
} |
||||
|
||||
pub fn add_edge(&mut self, parent: usize, child: usize) { |
||||
self.ctbn_network.add_edge(parent, child); |
||||
self.0.add_edge(parent, child); |
||||
} |
||||
|
||||
pub fn get_node_indices(&self) -> BTreeSet<usize> { |
||||
self.0.get_node_indices().collect() |
||||
} |
||||
|
||||
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.ctbn_network.get_parent_set(node) |
||||
self.0.get_parent_set(node) |
||||
} |
||||
|
||||
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.ctbn_network.get_children_set(node) |
||||
self.0.get_children_set(node) |
||||
} |
||||
|
||||
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) { |
||||
match &n.0 { |
||||
Params::DiscreteStatesContinousTime(new_p) => { |
||||
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){ |
||||
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap(); |
||||
|
||||
} |
||||
else { |
||||
panic!("Node type mismatch") |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub fn trajectory_generator( |
||||
&self, |
||||
n_trajectories: u64, |
||||
t_end: f64, |
||||
seed: Option<u64>, |
||||
) -> pytools::PyDataset { |
||||
pytools::PyDataset(tools::trajectory_generator( |
||||
&self.0, |
||||
n_trajectories, |
||||
t_end, |
||||
seed, |
||||
)) |
||||
} |
||||
} |
||||
|
@ -0,0 +1,50 @@ |
||||
use numpy::{self, ToPyArray}; |
||||
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||
use reCTBN::{tools, network}; |
||||
|
||||
#[pyclass] |
||||
#[derive(Clone)] |
||||
pub struct PyTrajectory(pub tools::Trajectory); |
||||
|
||||
#[pymethods] |
||||
impl PyTrajectory { |
||||
#[new] |
||||
pub fn new( |
||||
time: numpy::PyReadonlyArray1<f64>, |
||||
events: numpy::PyReadonlyArray2<usize>, |
||||
) -> PyTrajectory { |
||||
PyTrajectory(tools::Trajectory::new( |
||||
time.as_array().to_owned(), |
||||
events.as_array().to_owned(), |
||||
)) |
||||
} |
||||
|
||||
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> { |
||||
self.0.get_time().to_pyarray(py) |
||||
} |
||||
|
||||
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> { |
||||
self.0.get_events().to_pyarray(py) |
||||
} |
||||
} |
||||
|
||||
#[pyclass] |
||||
pub struct PyDataset(pub tools::Dataset); |
||||
|
||||
#[pymethods] |
||||
impl PyDataset { |
||||
#[new] |
||||
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset { |
||||
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect())) |
||||
} |
||||
|
||||
pub fn get_number_of_trajectories(&self) -> usize { |
||||
self.0.get_trajectories().len() |
||||
} |
||||
|
||||
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory { |
||||
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone()) |
||||
} |
||||
|
||||
} |
||||
|
Loading…
Reference in new issue