pull/59/head
AlessandroBregoli 2 years ago
parent a07dc214ce
commit 622cd305d0
  1. 2
      src/lib.rs
  2. 15
      src/sampling.rs
  3. 2
      src/tools.rs

@ -6,6 +6,6 @@ pub mod ctbn;
pub mod network; pub mod network;
pub mod parameter_learning; pub mod parameter_learning;
pub mod params; pub mod params;
pub mod sampling;
pub mod structure_learning; pub mod structure_learning;
pub mod tools; pub mod tools;
pub mod sampling;

@ -1,5 +1,5 @@
use crate::{ use crate::{
network::{Network}, network::Network,
params::{self, ParamsTrait}, params::{self, ParamsTrait},
}; };
use rand::SeedableRng; use rand::SeedableRng;
@ -65,7 +65,8 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
} }
} }
let next_node_transition = self.next_transitions let next_node_transition = self
.next_transitions
.iter() .iter()
.enumerate() .enumerate()
.min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap()) .min_by(|x, y| x.1.unwrap().partial_cmp(&y.1.unwrap()).unwrap())
@ -74,12 +75,15 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
self.current_time = self.next_transitions[next_node_transition].unwrap().clone(); self.current_time = self.next_transitions[next_node_transition].unwrap().clone();
self.current_state[next_node_transition] = self.net self.current_state[next_node_transition] = self
.net
.get_node(next_node_transition) .get_node(next_node_transition)
.get_random_state( .get_random_state(
self.net.get_node(next_node_transition) self.net
.get_node(next_node_transition)
.state_to_index(&self.current_state[next_node_transition]), .state_to_index(&self.current_state[next_node_transition]),
self.net.get_param_index_network(next_node_transition, &self.current_state), self.net
.get_param_index_network(next_node_transition, &self.current_state),
&mut self.rng, &mut self.rng,
) )
.unwrap(); .unwrap();
@ -90,7 +94,6 @@ impl<'a, T: Network> Iterator for ForwardSampler<'a, T> {
self.next_transitions[child] = None; self.next_transitions[child] = None;
} }
Some((ret_time, ret_state)) Some((ret_time, ret_state))
} }
} }

@ -1,6 +1,6 @@
use ndarray::prelude::*; use ndarray::prelude::*;
use crate::sampling::{Sampler, ForwardSampler}; use crate::sampling::{ForwardSampler, Sampler};
use crate::{network, params}; use crate::{network, params};
pub struct Trajectory { pub struct Trajectory {

Loading…
Cancel
Save