parent
3dae67a80c
commit
aa630bf9e9
@ -1,35 +1,68 @@ |
|||||||
use std::collections::BTreeSet; |
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use crate::{pyparams, pytools}; |
||||||
use pyo3::prelude::*; |
use pyo3::prelude::*; |
||||||
use reCTBN::{ctbn, network::Network}; |
use reCTBN::{ctbn, network::Network, params, tools, params::Params}; |
||||||
|
|
||||||
#[pyclass] |
#[pyclass] |
||||||
pub struct PyCtbnNetwork { |
pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork); |
||||||
ctbn_network: ctbn::CtbnNetwork, |
|
||||||
} |
|
||||||
|
|
||||||
#[pymethods] |
#[pymethods] |
||||||
impl PyCtbnNetwork { |
impl PyCtbnNetwork { |
||||||
#[new] |
#[new] |
||||||
pub fn new() -> Self { |
pub fn new() -> Self { |
||||||
PyCtbnNetwork { |
PyCtbnNetwork(ctbn::CtbnNetwork::new()) |
||||||
ctbn_network: ctbn::CtbnNetwork::new(), |
|
||||||
} |
} |
||||||
|
|
||||||
|
pub fn add_node(&mut self, n: pyparams::PyParams) { |
||||||
|
self.0.add_node(n.0); |
||||||
} |
} |
||||||
|
|
||||||
pub fn get_number_of_nodes(&self) -> usize { |
pub fn get_number_of_nodes(&self) -> usize { |
||||||
self.ctbn_network.get_number_of_nodes() |
self.0.get_number_of_nodes() |
||||||
} |
} |
||||||
|
|
||||||
pub fn add_edge(&mut self, parent: usize, child: usize) { |
pub fn add_edge(&mut self, parent: usize, child: usize) { |
||||||
self.ctbn_network.add_edge(parent, child); |
self.0.add_edge(parent, child); |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_node_indices(&self) -> BTreeSet<usize> { |
||||||
|
self.0.get_node_indices().collect() |
||||||
} |
} |
||||||
|
|
||||||
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
self.ctbn_network.get_parent_set(node) |
self.0.get_parent_set(node) |
||||||
} |
} |
||||||
|
|
||||||
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
self.ctbn_network.get_children_set(node) |
self.0.get_children_set(node) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) { |
||||||
|
match &n.0 { |
||||||
|
Params::DiscreteStatesContinousTime(new_p) => { |
||||||
|
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){ |
||||||
|
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap(); |
||||||
|
|
||||||
|
} |
||||||
|
else { |
||||||
|
panic!("Node type mismatch") |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn trajectory_generator( |
||||||
|
&self, |
||||||
|
n_trajectories: u64, |
||||||
|
t_end: f64, |
||||||
|
seed: Option<u64>, |
||||||
|
) -> pytools::PyDataset { |
||||||
|
pytools::PyDataset(tools::trajectory_generator( |
||||||
|
&self.0, |
||||||
|
n_trajectories, |
||||||
|
t_end, |
||||||
|
seed, |
||||||
|
)) |
||||||
} |
} |
||||||
} |
} |
||||||
|
@ -0,0 +1,50 @@ |
|||||||
|
use numpy::{self, ToPyArray}; |
||||||
|
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||||
|
use reCTBN::{tools, network}; |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
#[derive(Clone)] |
||||||
|
pub struct PyTrajectory(pub tools::Trajectory); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyTrajectory { |
||||||
|
#[new] |
||||||
|
pub fn new( |
||||||
|
time: numpy::PyReadonlyArray1<f64>, |
||||||
|
events: numpy::PyReadonlyArray2<usize>, |
||||||
|
) -> PyTrajectory { |
||||||
|
PyTrajectory(tools::Trajectory::new( |
||||||
|
time.as_array().to_owned(), |
||||||
|
events.as_array().to_owned(), |
||||||
|
)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> { |
||||||
|
self.0.get_time().to_pyarray(py) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> { |
||||||
|
self.0.get_events().to_pyarray(py) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyDataset(pub tools::Dataset); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyDataset { |
||||||
|
#[new] |
||||||
|
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset { |
||||||
|
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect())) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_number_of_trajectories(&self) -> usize { |
||||||
|
self.0.get_trajectories().len() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory { |
||||||
|
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone()) |
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
|
Loading…
Reference in new issue