|
|
|
@ -1,9 +1,12 @@ |
|
|
|
|
//! Contains commonly used methods used across the crate.
|
|
|
|
|
|
|
|
|
|
use ndarray::prelude::*; |
|
|
|
|
use std::ops::{DivAssign, MulAssign, Range}; |
|
|
|
|
|
|
|
|
|
use ndarray::{Array, Array1, Array2, Array3, Axis}; |
|
|
|
|
use rand::{Rng, SeedableRng}; |
|
|
|
|
use rand_chacha::ChaCha8Rng; |
|
|
|
|
|
|
|
|
|
use crate::params::ParamsTrait; |
|
|
|
|
use crate::process::NetworkProcess; |
|
|
|
|
use crate::sampling::{ForwardSampler, Sampler}; |
|
|
|
|
use crate::{params, process}; |
|
|
|
@ -151,3 +154,90 @@ impl RandomGraphGenerator for UniformGraphGenerator { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub trait RandomParametersGenerator { |
|
|
|
|
fn new(interval: Range<f64>, seed: Option<u64>) -> Self; |
|
|
|
|
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct UniformParametersGenerator { |
|
|
|
|
interval: Range<f64>, |
|
|
|
|
rng: ChaCha8Rng, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl RandomParametersGenerator for UniformParametersGenerator { |
|
|
|
|
fn new(interval: Range<f64>, seed: Option<u64>) -> UniformParametersGenerator { |
|
|
|
|
if interval.start < 0.0 || interval.end < 0.0 { |
|
|
|
|
panic!( |
|
|
|
|
"Interval must be entirely less or equal than 0, got {}..{}.", |
|
|
|
|
interval.start, interval.end |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
let rng: ChaCha8Rng = match seed { |
|
|
|
|
Some(seed) => SeedableRng::seed_from_u64(seed), |
|
|
|
|
None => SeedableRng::from_entropy(), |
|
|
|
|
}; |
|
|
|
|
UniformParametersGenerator { interval, rng } |
|
|
|
|
} |
|
|
|
|
fn generate_parameters<T: NetworkProcess>(&mut self, net: &mut T) { |
|
|
|
|
for node in net.get_node_indices() { |
|
|
|
|
let parent_set = net.get_parent_set(node); |
|
|
|
|
let parent_set_state_space_cardinality: usize = parent_set |
|
|
|
|
.iter() |
|
|
|
|
.map(|x| net.get_node(*x).get_reserved_space_as_parent()) |
|
|
|
|
.product(); |
|
|
|
|
println!( |
|
|
|
|
"parent_set_state_space_cardinality = {}", |
|
|
|
|
parent_set_state_space_cardinality |
|
|
|
|
); |
|
|
|
|
let node_domain_cardinality = net.get_node(node).get_reserved_space_as_parent(); |
|
|
|
|
println!("node_domain_cardinality = {}", node_domain_cardinality); |
|
|
|
|
let cim_single_param_range = (self.interval.start / node_domain_cardinality as f64) |
|
|
|
|
..=(self.interval.end / node_domain_cardinality as f64); |
|
|
|
|
println!("cim_single_param_range = {:?}", cim_single_param_range); |
|
|
|
|
|
|
|
|
|
let mut cim = Array3::<f64>::from_shape_fn( |
|
|
|
|
( |
|
|
|
|
parent_set_state_space_cardinality, |
|
|
|
|
node_domain_cardinality, |
|
|
|
|
node_domain_cardinality, |
|
|
|
|
), |
|
|
|
|
|_| self.rng.gen(), |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
//let diagonal = cim.axis_iter(Axis(0));
|
|
|
|
|
cim.axis_iter_mut(Axis(0)) |
|
|
|
|
.for_each(|mut x| x.diag_mut().iter_mut().for_each(|x| println!("{x}"))); |
|
|
|
|
cim.axis_iter_mut(Axis(0)).for_each(|mut x| { |
|
|
|
|
x.diag_mut().fill(0.0); |
|
|
|
|
let sum_axis = x.sum_axis(Axis(0)); |
|
|
|
|
//let division = 1.0 / &sum_axis;
|
|
|
|
|
x.div_assign(&sum_axis); |
|
|
|
|
println!("{}", x); |
|
|
|
|
let diag = Array1::<f64>::from_shape_fn(node_domain_cardinality, |_| { |
|
|
|
|
self.rng.gen_range(self.interval.clone()) |
|
|
|
|
}); |
|
|
|
|
x.mul_assign(&diag); |
|
|
|
|
println!("{}", x); |
|
|
|
|
x.diag_mut().assign(&-diag) |
|
|
|
|
}); |
|
|
|
|
cim.axis_iter_mut(Axis(0)) |
|
|
|
|
.for_each(|x| x.diag().iter().for_each(|x| println!("{x}"))); |
|
|
|
|
|
|
|
|
|
println!("Sum Axis"); |
|
|
|
|
cim.axis_iter_mut(Axis(0)) |
|
|
|
|
.for_each(|x| x.sum_axis(Axis(0)).iter().for_each(|x| println!("{x}"))); |
|
|
|
|
println!("Matrices"); |
|
|
|
|
cim.axis_iter_mut(Axis(0)) |
|
|
|
|
.for_each(|x| x.iter().for_each(|x| println!("{}", x))); |
|
|
|
|
//.any(|x| x.diag().iter().any(|x| x >= &0.0))
|
|
|
|
|
|
|
|
|
|
//println!("{:?}", diagonal);
|
|
|
|
|
match &mut net.get_node_mut(node) { |
|
|
|
|
params::Params::DiscreteStatesContinousTime(param) => { |
|
|
|
|
assert_eq!(Ok(()), param.set_cim(cim)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|