Compare commits
3 Commits
dev
...
62-feature
Author | SHA1 | Date |
---|---|---|
AlessandroBregoli | e1af14b620 | 2 years ago |
AlessandroBregoli | aa630bf9e9 | 2 years ago |
AlessandroBregoli | 3dae67a80c | 2 years ago |
@ -1,3 +1,4 @@ |
|||||||
/target |
/target |
||||||
Cargo.lock |
Cargo.lock |
||||||
.vscode |
.vscode |
||||||
|
poetry.lock |
||||||
|
@ -0,0 +1,16 @@ |
|||||||
|
[tool.poetry] |
||||||
|
name = "rectbnpy" |
||||||
|
version = "0.1.0" |
||||||
|
description = "" |
||||||
|
authors = ["AlessandroBregoli <alessandroxciv@gmail.com>"] |
||||||
|
readme = "README.md" |
||||||
|
|
||||||
|
[tool.poetry.dependencies] |
||||||
|
python = "^3.10" |
||||||
|
maturin = "^0.13.3" |
||||||
|
numpy = "^1.23.3" |
||||||
|
|
||||||
|
|
||||||
|
[build-system] |
||||||
|
requires = ["poetry-core"] |
||||||
|
build-backend = "poetry.core.masonry.api" |
@ -0,0 +1,69 @@ |
|||||||
|
name: CI |
||||||
|
|
||||||
|
on: |
||||||
|
push: |
||||||
|
branches: |
||||||
|
- main |
||||||
|
- master |
||||||
|
pull_request: |
||||||
|
|
||||||
|
jobs: |
||||||
|
linux: |
||||||
|
runs-on: ubuntu-latest |
||||||
|
steps: |
||||||
|
- uses: actions/checkout@v3 |
||||||
|
- uses: messense/maturin-action@v1 |
||||||
|
with: |
||||||
|
manylinux: auto |
||||||
|
command: build |
||||||
|
args: --release --sdist -o dist --find-interpreter |
||||||
|
- name: Upload wheels |
||||||
|
uses: actions/upload-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
path: dist |
||||||
|
|
||||||
|
windows: |
||||||
|
runs-on: windows-latest |
||||||
|
steps: |
||||||
|
- uses: actions/checkout@v3 |
||||||
|
- uses: messense/maturin-action@v1 |
||||||
|
with: |
||||||
|
command: build |
||||||
|
args: --release -o dist --find-interpreter |
||||||
|
- name: Upload wheels |
||||||
|
uses: actions/upload-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
path: dist |
||||||
|
|
||||||
|
macos: |
||||||
|
runs-on: macos-latest |
||||||
|
steps: |
||||||
|
- uses: actions/checkout@v3 |
||||||
|
- uses: messense/maturin-action@v1 |
||||||
|
with: |
||||||
|
command: build |
||||||
|
args: --release -o dist --universal2 --find-interpreter |
||||||
|
- name: Upload wheels |
||||||
|
uses: actions/upload-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
path: dist |
||||||
|
|
||||||
|
release: |
||||||
|
name: Release |
||||||
|
runs-on: ubuntu-latest |
||||||
|
if: "startsWith(github.ref, 'refs/tags/')" |
||||||
|
needs: [ macos, windows, linux ] |
||||||
|
steps: |
||||||
|
- uses: actions/download-artifact@v2 |
||||||
|
with: |
||||||
|
name: wheels |
||||||
|
- name: Publish to PyPI |
||||||
|
uses: messense/maturin-action@v1 |
||||||
|
env: |
||||||
|
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} |
||||||
|
with: |
||||||
|
command: upload |
||||||
|
args: --skip-existing * |
@ -0,0 +1,72 @@ |
|||||||
|
/target |
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files |
||||||
|
__pycache__/ |
||||||
|
.pytest_cache/ |
||||||
|
*.py[cod] |
||||||
|
|
||||||
|
# C extensions |
||||||
|
*.so |
||||||
|
|
||||||
|
# Distribution / packaging |
||||||
|
.Python |
||||||
|
.venv/ |
||||||
|
env/ |
||||||
|
bin/ |
||||||
|
build/ |
||||||
|
develop-eggs/ |
||||||
|
dist/ |
||||||
|
eggs/ |
||||||
|
lib/ |
||||||
|
lib64/ |
||||||
|
parts/ |
||||||
|
sdist/ |
||||||
|
var/ |
||||||
|
include/ |
||||||
|
man/ |
||||||
|
venv/ |
||||||
|
*.egg-info/ |
||||||
|
.installed.cfg |
||||||
|
*.egg |
||||||
|
|
||||||
|
# Installer logs |
||||||
|
pip-log.txt |
||||||
|
pip-delete-this-directory.txt |
||||||
|
pip-selfcheck.json |
||||||
|
|
||||||
|
# Unit test / coverage reports |
||||||
|
htmlcov/ |
||||||
|
.tox/ |
||||||
|
.coverage |
||||||
|
.cache |
||||||
|
nosetests.xml |
||||||
|
coverage.xml |
||||||
|
|
||||||
|
# Translations |
||||||
|
*.mo |
||||||
|
|
||||||
|
# Mr Developer |
||||||
|
.mr.developer.cfg |
||||||
|
.project |
||||||
|
.pydevproject |
||||||
|
|
||||||
|
# Rope |
||||||
|
.ropeproject |
||||||
|
|
||||||
|
# Django stuff: |
||||||
|
*.log |
||||||
|
*.pot |
||||||
|
|
||||||
|
.DS_Store |
||||||
|
|
||||||
|
# Sphinx documentation |
||||||
|
docs/_build/ |
||||||
|
|
||||||
|
# PyCharm |
||||||
|
.idea/ |
||||||
|
|
||||||
|
# VSCode |
||||||
|
.vscode/ |
||||||
|
|
||||||
|
# Pyenv |
||||||
|
.python-version |
@ -0,0 +1,15 @@ |
|||||||
|
[package] |
||||||
|
name = "reCTBNpy" |
||||||
|
version = "0.1.0" |
||||||
|
edition = "2021" |
||||||
|
|
||||||
|
|
||||||
|
[lib] |
||||||
|
crate-type = ["cdylib"] |
||||||
|
|
||||||
|
[workspace] |
||||||
|
|
||||||
|
[dependencies] |
||||||
|
pyo3 = { version = "0.17.1", features = ["extension-module"] } |
||||||
|
numpy = "*" |
||||||
|
reCTBN = { path="../reCTBN" } |
@ -0,0 +1,14 @@ |
|||||||
|
[build-system] |
||||||
|
requires = ["maturin>=0.13,<0.14"] |
||||||
|
build-backend = "maturin" |
||||||
|
|
||||||
|
[project] |
||||||
|
name = "reCTBNpy" |
||||||
|
requires-python = ">=3.7" |
||||||
|
classifiers = [ |
||||||
|
"Programming Language :: Rust", |
||||||
|
"Programming Language :: Python :: Implementation :: CPython", |
||||||
|
"Programming Language :: Python :: Implementation :: PyPy", |
||||||
|
] |
||||||
|
|
||||||
|
|
@ -0,0 +1,21 @@ |
|||||||
|
use pyo3::prelude::*; |
||||||
|
pub mod pyctbn; |
||||||
|
pub mod pyparams; |
||||||
|
pub mod pytools; |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// A Python module implemented in Rust.
|
||||||
|
#[pymodule] |
||||||
|
fn reCTBNpy(py: Python, m: &PyModule) -> PyResult<()> { |
||||||
|
let network_module = PyModule::new(py, "network")?; |
||||||
|
network_module.add_class::<pyctbn::PyCtbnNetwork>()?; |
||||||
|
m.add_submodule(network_module)?; |
||||||
|
|
||||||
|
let params_module = PyModule::new(py, "params")?; |
||||||
|
params_module.add_class::<pyparams::PyDiscreteStateContinousTime>()?; |
||||||
|
params_module.add_class::<pyparams::PyStateType>()?; |
||||||
|
params_module.add_class::<pyparams::PyParams>()?; |
||||||
|
m.add_submodule(params_module)?; |
||||||
|
Ok(()) |
||||||
|
} |
@ -0,0 +1,68 @@ |
|||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
use crate::{pyparams, pytools}; |
||||||
|
use pyo3::prelude::*; |
||||||
|
use reCTBN::{ctbn, network::Network, params, tools, params::Params}; |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyCtbnNetwork { |
||||||
|
#[new] |
||||||
|
pub fn new() -> Self { |
||||||
|
PyCtbnNetwork(ctbn::CtbnNetwork::new()) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn add_node(&mut self, n: pyparams::PyParams) { |
||||||
|
self.0.add_node(n.0); |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_number_of_nodes(&self) -> usize { |
||||||
|
self.0.get_number_of_nodes() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn add_edge(&mut self, parent: usize, child: usize) { |
||||||
|
self.0.add_edge(parent, child); |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_node_indices(&self) -> BTreeSet<usize> { |
||||||
|
self.0.get_node_indices().collect() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.0.get_parent_set(node) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||||
|
self.0.get_children_set(node) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) { |
||||||
|
match &n.0 { |
||||||
|
Params::DiscreteStatesContinousTime(new_p) => { |
||||||
|
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){ |
||||||
|
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap(); |
||||||
|
|
||||||
|
} |
||||||
|
else { |
||||||
|
panic!("Node type mismatch") |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn trajectory_generator( |
||||||
|
&self, |
||||||
|
n_trajectories: u64, |
||||||
|
t_end: f64, |
||||||
|
seed: Option<u64>, |
||||||
|
) -> pytools::PyDataset { |
||||||
|
pytools::PyDataset(tools::trajectory_generator( |
||||||
|
&self.0, |
||||||
|
n_trajectories, |
||||||
|
t_end, |
||||||
|
seed, |
||||||
|
)) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,102 @@ |
|||||||
|
use numpy::{self, ToPyArray}; |
||||||
|
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||||
|
use reCTBN::params::{self, ParamsTrait}; |
||||||
|
use std::collections::BTreeSet; |
||||||
|
|
||||||
|
pub struct PyParamsError(params::ParamsError); |
||||||
|
|
||||||
|
impl From<PyParamsError> for PyErr { |
||||||
|
fn from(error: PyParamsError) -> Self { |
||||||
|
PyValueError::new_err(error.0.to_string()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl From<params::ParamsError> for PyParamsError { |
||||||
|
fn from(other: params::ParamsError) -> Self { |
||||||
|
Self(other) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyStateType(pub params::StateType); |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
#[derive(Clone)] |
||||||
|
pub struct PyParams(pub params::Params); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyParams { |
||||||
|
#[staticmethod] |
||||||
|
pub fn new_discrete_state_continous_time(p: PyDiscreteStateContinousTime) -> Self{ |
||||||
|
PyParams(params::Params::DiscreteStatesContinousTime(p.0)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_reserved_space_as_parent(&self) -> usize { |
||||||
|
self.0.get_reserved_space_as_parent() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_label(&self) -> String { |
||||||
|
self.0.get_label().to_string() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// DiscreteStatesContinousTime.
|
||||||
|
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
|
||||||
|
/// following elements:
|
||||||
|
/// - **domain**: an ordered and exhaustive set of possible states
|
||||||
|
/// - **cim**: Conditional Intensity Matrix
|
||||||
|
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
|
||||||
|
/// learning task and are composed by:
|
||||||
|
/// - **transitions**: number of transitions from one state to another given a specific
|
||||||
|
/// realization of the parent set
|
||||||
|
/// - **residence_time**: permanence time in each possible states given a specific
|
||||||
|
/// realization of the parent set
|
||||||
|
#[derive(Clone)] |
||||||
|
#[pyclass] |
||||||
|
pub struct PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams); |
||||||
|
|
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyDiscreteStateContinousTime { |
||||||
|
#[new] |
||||||
|
pub fn new(label: String, domain: BTreeSet<String>) -> Self { |
||||||
|
PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams::new(label, domain)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<f64>> { |
||||||
|
match self.0.get_cim() { |
||||||
|
Some(x) => Some(x.to_pyarray(py)), |
||||||
|
None => None, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_cim<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<f64>) -> Result<(), PyParamsError> { |
||||||
|
self.0.set_cim(cim.as_array().to_owned())?; |
||||||
|
Ok(()) |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
pub fn get_transitions<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<usize>> { |
||||||
|
match self.0.get_transitions() { |
||||||
|
Some(x) => Some(x.to_pyarray(py)), |
||||||
|
None => None, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_transitions<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<usize>){ |
||||||
|
self.0.set_transitions(cim.as_array().to_owned()); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
pub fn get_residence_time<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray2<f64>> { |
||||||
|
match self.0.get_residence_time() { |
||||||
|
Some(x) => Some(x.to_pyarray(py)), |
||||||
|
None => None, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
pub fn set_residence_time<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray2<f64>) { |
||||||
|
self.0.set_residence_time(cim.as_array().to_owned()); |
||||||
|
} |
||||||
|
|
||||||
|
} |
@ -0,0 +1,50 @@ |
|||||||
|
use numpy::{self, ToPyArray}; |
||||||
|
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||||
|
use reCTBN::{tools, network}; |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
#[derive(Clone)] |
||||||
|
pub struct PyTrajectory(pub tools::Trajectory); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyTrajectory { |
||||||
|
#[new] |
||||||
|
pub fn new( |
||||||
|
time: numpy::PyReadonlyArray1<f64>, |
||||||
|
events: numpy::PyReadonlyArray2<usize>, |
||||||
|
) -> PyTrajectory { |
||||||
|
PyTrajectory(tools::Trajectory::new( |
||||||
|
time.as_array().to_owned(), |
||||||
|
events.as_array().to_owned(), |
||||||
|
)) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> { |
||||||
|
self.0.get_time().to_pyarray(py) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> { |
||||||
|
self.0.get_events().to_pyarray(py) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
#[pyclass] |
||||||
|
pub struct PyDataset(pub tools::Dataset); |
||||||
|
|
||||||
|
#[pymethods] |
||||||
|
impl PyDataset { |
||||||
|
#[new] |
||||||
|
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset { |
||||||
|
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect())) |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_number_of_trajectories(&self) -> usize { |
||||||
|
self.0.get_trajectories().len() |
||||||
|
} |
||||||
|
|
||||||
|
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory { |
||||||
|
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone()) |
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
|
Loading…
Reference in new issue