Compare commits
3 Commits
dev
...
62-feature
Author | SHA1 | Date |
---|---|---|
AlessandroBregoli | e1af14b620 | 2 years ago |
AlessandroBregoli | aa630bf9e9 | 2 years ago |
AlessandroBregoli | 3dae67a80c | 2 years ago |
@ -1,3 +1,4 @@ |
||||
/target |
||||
Cargo.lock |
||||
.vscode |
||||
poetry.lock |
||||
|
@ -0,0 +1,16 @@ |
||||
[tool.poetry] |
||||
name = "rectbnpy" |
||||
version = "0.1.0" |
||||
description = "" |
||||
authors = ["AlessandroBregoli <alessandroxciv@gmail.com>"] |
||||
readme = "README.md" |
||||
|
||||
[tool.poetry.dependencies] |
||||
python = "^3.10" |
||||
maturin = "^0.13.3" |
||||
numpy = "^1.23.3" |
||||
|
||||
|
||||
[build-system] |
||||
requires = ["poetry-core"] |
||||
build-backend = "poetry.core.masonry.api" |
@ -0,0 +1,69 @@ |
||||
name: CI |
||||
|
||||
on: |
||||
push: |
||||
branches: |
||||
- main |
||||
- master |
||||
pull_request: |
||||
|
||||
jobs: |
||||
linux: |
||||
runs-on: ubuntu-latest |
||||
steps: |
||||
- uses: actions/checkout@v3 |
||||
- uses: messense/maturin-action@v1 |
||||
with: |
||||
manylinux: auto |
||||
command: build |
||||
args: --release --sdist -o dist --find-interpreter |
||||
- name: Upload wheels |
||||
uses: actions/upload-artifact@v2 |
||||
with: |
||||
name: wheels |
||||
path: dist |
||||
|
||||
windows: |
||||
runs-on: windows-latest |
||||
steps: |
||||
- uses: actions/checkout@v3 |
||||
- uses: messense/maturin-action@v1 |
||||
with: |
||||
command: build |
||||
args: --release -o dist --find-interpreter |
||||
- name: Upload wheels |
||||
uses: actions/upload-artifact@v2 |
||||
with: |
||||
name: wheels |
||||
path: dist |
||||
|
||||
macos: |
||||
runs-on: macos-latest |
||||
steps: |
||||
- uses: actions/checkout@v3 |
||||
- uses: messense/maturin-action@v1 |
||||
with: |
||||
command: build |
||||
args: --release -o dist --universal2 --find-interpreter |
||||
- name: Upload wheels |
||||
uses: actions/upload-artifact@v2 |
||||
with: |
||||
name: wheels |
||||
path: dist |
||||
|
||||
release: |
||||
name: Release |
||||
runs-on: ubuntu-latest |
||||
if: "startsWith(github.ref, 'refs/tags/')" |
||||
needs: [ macos, windows, linux ] |
||||
steps: |
||||
- uses: actions/download-artifact@v2 |
||||
with: |
||||
name: wheels |
||||
- name: Publish to PyPI |
||||
uses: messense/maturin-action@v1 |
||||
env: |
||||
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} |
||||
with: |
||||
command: upload |
||||
args: --skip-existing * |
@ -0,0 +1,72 @@ |
||||
/target |
||||
|
||||
# Byte-compiled / optimized / DLL files |
||||
__pycache__/ |
||||
.pytest_cache/ |
||||
*.py[cod] |
||||
|
||||
# C extensions |
||||
*.so |
||||
|
||||
# Distribution / packaging |
||||
.Python |
||||
.venv/ |
||||
env/ |
||||
bin/ |
||||
build/ |
||||
develop-eggs/ |
||||
dist/ |
||||
eggs/ |
||||
lib/ |
||||
lib64/ |
||||
parts/ |
||||
sdist/ |
||||
var/ |
||||
include/ |
||||
man/ |
||||
venv/ |
||||
*.egg-info/ |
||||
.installed.cfg |
||||
*.egg |
||||
|
||||
# Installer logs |
||||
pip-log.txt |
||||
pip-delete-this-directory.txt |
||||
pip-selfcheck.json |
||||
|
||||
# Unit test / coverage reports |
||||
htmlcov/ |
||||
.tox/ |
||||
.coverage |
||||
.cache |
||||
nosetests.xml |
||||
coverage.xml |
||||
|
||||
# Translations |
||||
*.mo |
||||
|
||||
# Mr Developer |
||||
.mr.developer.cfg |
||||
.project |
||||
.pydevproject |
||||
|
||||
# Rope |
||||
.ropeproject |
||||
|
||||
# Django stuff: |
||||
*.log |
||||
*.pot |
||||
|
||||
.DS_Store |
||||
|
||||
# Sphinx documentation |
||||
docs/_build/ |
||||
|
||||
# PyCharm |
||||
.idea/ |
||||
|
||||
# VSCode |
||||
.vscode/ |
||||
|
||||
# Pyenv |
||||
.python-version |
@ -0,0 +1,15 @@ |
||||
[package] |
||||
name = "reCTBNpy" |
||||
version = "0.1.0" |
||||
edition = "2021" |
||||
|
||||
|
||||
[lib] |
||||
crate-type = ["cdylib"] |
||||
|
||||
[workspace] |
||||
|
||||
[dependencies] |
||||
pyo3 = { version = "0.17.1", features = ["extension-module"] } |
||||
numpy = "*" |
||||
reCTBN = { path="../reCTBN" } |
@ -0,0 +1,14 @@ |
||||
[build-system] |
||||
requires = ["maturin>=0.13,<0.14"] |
||||
build-backend = "maturin" |
||||
|
||||
[project] |
||||
name = "reCTBNpy" |
||||
requires-python = ">=3.7" |
||||
classifiers = [ |
||||
"Programming Language :: Rust", |
||||
"Programming Language :: Python :: Implementation :: CPython", |
||||
"Programming Language :: Python :: Implementation :: PyPy", |
||||
] |
||||
|
||||
|
@ -0,0 +1,21 @@ |
||||
use pyo3::prelude::*; |
||||
pub mod pyctbn; |
||||
pub mod pyparams; |
||||
pub mod pytools; |
||||
|
||||
|
||||
|
||||
/// A Python module implemented in Rust.
|
||||
#[pymodule] |
||||
fn reCTBNpy(py: Python, m: &PyModule) -> PyResult<()> { |
||||
let network_module = PyModule::new(py, "network")?; |
||||
network_module.add_class::<pyctbn::PyCtbnNetwork>()?; |
||||
m.add_submodule(network_module)?; |
||||
|
||||
let params_module = PyModule::new(py, "params")?; |
||||
params_module.add_class::<pyparams::PyDiscreteStateContinousTime>()?; |
||||
params_module.add_class::<pyparams::PyStateType>()?; |
||||
params_module.add_class::<pyparams::PyParams>()?; |
||||
m.add_submodule(params_module)?; |
||||
Ok(()) |
||||
} |
@ -0,0 +1,68 @@ |
||||
use std::collections::BTreeSet; |
||||
|
||||
use crate::{pyparams, pytools}; |
||||
use pyo3::prelude::*; |
||||
use reCTBN::{ctbn, network::Network, params, tools, params::Params}; |
||||
|
||||
#[pyclass] |
||||
pub struct PyCtbnNetwork(pub ctbn::CtbnNetwork); |
||||
|
||||
#[pymethods] |
||||
impl PyCtbnNetwork { |
||||
#[new] |
||||
pub fn new() -> Self { |
||||
PyCtbnNetwork(ctbn::CtbnNetwork::new()) |
||||
} |
||||
|
||||
pub fn add_node(&mut self, n: pyparams::PyParams) { |
||||
self.0.add_node(n.0); |
||||
} |
||||
|
||||
pub fn get_number_of_nodes(&self) -> usize { |
||||
self.0.get_number_of_nodes() |
||||
} |
||||
|
||||
pub fn add_edge(&mut self, parent: usize, child: usize) { |
||||
self.0.add_edge(parent, child); |
||||
} |
||||
|
||||
pub fn get_node_indices(&self) -> BTreeSet<usize> { |
||||
self.0.get_node_indices().collect() |
||||
} |
||||
|
||||
pub fn get_parent_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.0.get_parent_set(node) |
||||
} |
||||
|
||||
pub fn get_children_set(&self, node: usize) -> BTreeSet<usize> { |
||||
self.0.get_children_set(node) |
||||
} |
||||
|
||||
pub fn set_node(&mut self, node_idx: usize, n: pyparams::PyParams) { |
||||
match &n.0 { |
||||
Params::DiscreteStatesContinousTime(new_p) => { |
||||
if let Params::DiscreteStatesContinousTime(p) = self.0.get_node_mut(node_idx){ |
||||
p.set_cim(new_p.get_cim().as_ref().unwrap().clone()).unwrap(); |
||||
|
||||
} |
||||
else { |
||||
panic!("Node type mismatch") |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub fn trajectory_generator( |
||||
&self, |
||||
n_trajectories: u64, |
||||
t_end: f64, |
||||
seed: Option<u64>, |
||||
) -> pytools::PyDataset { |
||||
pytools::PyDataset(tools::trajectory_generator( |
||||
&self.0, |
||||
n_trajectories, |
||||
t_end, |
||||
seed, |
||||
)) |
||||
} |
||||
} |
@ -0,0 +1,102 @@ |
||||
use numpy::{self, ToPyArray}; |
||||
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||
use reCTBN::params::{self, ParamsTrait}; |
||||
use std::collections::BTreeSet; |
||||
|
||||
pub struct PyParamsError(params::ParamsError); |
||||
|
||||
impl From<PyParamsError> for PyErr { |
||||
fn from(error: PyParamsError) -> Self { |
||||
PyValueError::new_err(error.0.to_string()) |
||||
} |
||||
} |
||||
|
||||
impl From<params::ParamsError> for PyParamsError { |
||||
fn from(other: params::ParamsError) -> Self { |
||||
Self(other) |
||||
} |
||||
} |
||||
|
||||
#[pyclass] |
||||
pub struct PyStateType(pub params::StateType); |
||||
|
||||
#[pyclass] |
||||
#[derive(Clone)] |
||||
pub struct PyParams(pub params::Params); |
||||
|
||||
#[pymethods] |
||||
impl PyParams { |
||||
#[staticmethod] |
||||
pub fn new_discrete_state_continous_time(p: PyDiscreteStateContinousTime) -> Self{ |
||||
PyParams(params::Params::DiscreteStatesContinousTime(p.0)) |
||||
} |
||||
|
||||
pub fn get_reserved_space_as_parent(&self) -> usize { |
||||
self.0.get_reserved_space_as_parent() |
||||
} |
||||
|
||||
pub fn get_label(&self) -> String { |
||||
self.0.get_label().to_string() |
||||
} |
||||
} |
||||
|
||||
/// DiscreteStatesContinousTime.
|
||||
/// This represents the parameters of a classical discrete node for ctbn and it's composed by the
|
||||
/// following elements:
|
||||
/// - **domain**: an ordered and exhaustive set of possible states
|
||||
/// - **cim**: Conditional Intensity Matrix
|
||||
/// - **Sufficient Statistics**: the sufficient statistics are mainly used during the parameter
|
||||
/// learning task and are composed by:
|
||||
/// - **transitions**: number of transitions from one state to another given a specific
|
||||
/// realization of the parent set
|
||||
/// - **residence_time**: permanence time in each possible states given a specific
|
||||
/// realization of the parent set
|
||||
#[derive(Clone)] |
||||
#[pyclass] |
||||
pub struct PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams); |
||||
|
||||
|
||||
#[pymethods] |
||||
impl PyDiscreteStateContinousTime { |
||||
#[new] |
||||
pub fn new(label: String, domain: BTreeSet<String>) -> Self { |
||||
PyDiscreteStateContinousTime(params::DiscreteStatesContinousTimeParams::new(label, domain)) |
||||
} |
||||
|
||||
pub fn get_cim<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<f64>> { |
||||
match self.0.get_cim() { |
||||
Some(x) => Some(x.to_pyarray(py)), |
||||
None => None, |
||||
} |
||||
} |
||||
|
||||
pub fn set_cim<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<f64>) -> Result<(), PyParamsError> { |
||||
self.0.set_cim(cim.as_array().to_owned())?; |
||||
Ok(()) |
||||
} |
||||
|
||||
|
||||
pub fn get_transitions<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray3<usize>> { |
||||
match self.0.get_transitions() { |
||||
Some(x) => Some(x.to_pyarray(py)), |
||||
None => None, |
||||
} |
||||
} |
||||
|
||||
pub fn set_transitions<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray3<usize>){ |
||||
self.0.set_transitions(cim.as_array().to_owned()); |
||||
} |
||||
|
||||
|
||||
pub fn get_residence_time<'py>(&self, py: Python<'py>) -> Option<&'py numpy::PyArray2<f64>> { |
||||
match self.0.get_residence_time() { |
||||
Some(x) => Some(x.to_pyarray(py)), |
||||
None => None, |
||||
} |
||||
} |
||||
|
||||
pub fn set_residence_time<'py>(&mut self, py: Python<'py>, cim: numpy::PyReadonlyArray2<f64>) { |
||||
self.0.set_residence_time(cim.as_array().to_owned()); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,50 @@ |
||||
use numpy::{self, ToPyArray}; |
||||
use pyo3::{exceptions::PyValueError, prelude::*}; |
||||
use reCTBN::{tools, network}; |
||||
|
||||
#[pyclass] |
||||
#[derive(Clone)] |
||||
pub struct PyTrajectory(pub tools::Trajectory); |
||||
|
||||
#[pymethods] |
||||
impl PyTrajectory { |
||||
#[new] |
||||
pub fn new( |
||||
time: numpy::PyReadonlyArray1<f64>, |
||||
events: numpy::PyReadonlyArray2<usize>, |
||||
) -> PyTrajectory { |
||||
PyTrajectory(tools::Trajectory::new( |
||||
time.as_array().to_owned(), |
||||
events.as_array().to_owned(), |
||||
)) |
||||
} |
||||
|
||||
pub fn get_time<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray1<f64> { |
||||
self.0.get_time().to_pyarray(py) |
||||
} |
||||
|
||||
pub fn get_events<'py>(&self, py: Python<'py>) -> &'py numpy::PyArray2<usize> { |
||||
self.0.get_events().to_pyarray(py) |
||||
} |
||||
} |
||||
|
||||
#[pyclass] |
||||
pub struct PyDataset(pub tools::Dataset); |
||||
|
||||
#[pymethods] |
||||
impl PyDataset { |
||||
#[new] |
||||
pub fn new(trajectories: Vec<PyTrajectory>) -> PyDataset { |
||||
PyDataset(tools::Dataset::new(trajectories.into_iter().map(|x| x.0).collect())) |
||||
} |
||||
|
||||
pub fn get_number_of_trajectories(&self) -> usize { |
||||
self.0.get_trajectories().len() |
||||
} |
||||
|
||||
pub fn get_trajectory(&self, idx: usize) -> PyTrajectory { |
||||
PyTrajectory(self.0.get_trajectories().get(idx).unwrap().clone()) |
||||
} |
||||
|
||||
} |
||||
|
Loading…
Reference in new issue