Skip to content

Commit

Permalink
major refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rileysu committed Nov 5, 2023
1 parent ed377f8 commit 49cd746
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 112 deletions.
196 changes: 133 additions & 63 deletions src/context/comp_graph.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::{HashSet, HashMap};

use slotmap::{SlotMap, new_key_type};
use thiserror::Error;

use crate::engine::{tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory}, Engine, EngineError};

Expand Down Expand Up @@ -36,10 +37,14 @@ impl<T: AllowedUnit> Node<T> {
///Borrows the edge recording how this node is produced (a `Root` or an operation over parent keys).
pub fn edge(&self) -> &Edge<T> {
    &self.edge
}

///True when this node is a starting point of computation (it has no parent edges).
pub fn is_root(&self) -> bool {
    matches!(self.edge(), Edge::Root)
}
}

//Tensor might be able to be combined inside edge so root without a defined tensor isn't possible
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Edge<T: AllowedUnit> {
Root,

Expand All @@ -66,12 +71,12 @@ impl<T: AllowedUnit> CompGraph<T> {
}
}

pub fn get_node(&self, node_idx: NodeKey) -> Option<&Node<T>> {
self.nodes.get(node_idx)
pub fn get_node(&self, node_key: NodeKey) -> Option<&Node<T>> {
self.nodes.get(node_key)
}

pub fn get_node_mut(&mut self, node_idx: NodeKey) -> Option<&mut Node<T>> {
self.nodes.get_mut(node_idx)
pub fn get_node_mut(&mut self, node_key: NodeKey) -> Option<&mut Node<T>> {
self.nodes.get_mut(node_key)
}

//Root is a node that is a starting point for computation
Expand All @@ -83,74 +88,125 @@ impl<T: AllowedUnit> CompGraph<T> {
self.nodes.insert(Node::create_node(edge))
}

//Single layer computation otherwise should throw an error
//Computes the tensor for `target_key` from its already-computed parents; does not recurse.
fn compute_tensor(&self, target_key: NodeKey) -> Result<EngineTensor<T>, ComputationGraphError> {
    let target = self
        .get_node(target_key)
        .ok_or(ComputationGraphError::NodeKeyDoesNotExist(target_key))?;

    match target.edge() {
        //A root holds its tensor directly; asking to compute one is an error.
        Edge::Root => Err(ComputationGraphError::RootNodeNotComputed(target_key)),

        //Unary operations: look up the single parent, then apply the stored op to its tensor.
        Edge::Abs(parent_key, op) | Edge::Neg(parent_key, op) => {
            let parent = self
                .get_node(*parent_key)
                .ok_or(ComputationGraphError::ParentNodeDoesNotExist(*parent_key))?;
            let parent_tensor = parent
                .tensor()
                .ok_or(ComputationGraphError::ParentNodeNotComputed(*parent_key))?;

            op(parent_tensor).map_err(ComputationGraphError::from)
        },

        //Binary operations: look up both parents, then apply the stored op to their tensors.
        Edge::Add(lhs_key, rhs_key, op)
        | Edge::Sub(lhs_key, rhs_key, op)
        | Edge::Mul(lhs_key, rhs_key, op)
        | Edge::Div(lhs_key, rhs_key, op) => {
            let lhs = self
                .get_node(*lhs_key)
                .ok_or(ComputationGraphError::ParentNodeDoesNotExist(*lhs_key))?;
            let rhs = self
                .get_node(*rhs_key)
                .ok_or(ComputationGraphError::ParentNodeDoesNotExist(*rhs_key))?;

            let lhs_tensor = lhs
                .tensor()
                .ok_or(ComputationGraphError::ParentNodeNotComputed(*lhs_key))?;
            let rhs_tensor = rhs
                .tensor()
                .ok_or(ComputationGraphError::ParentNodeNotComputed(*rhs_key))?;

            op(lhs_tensor, rhs_tensor).map_err(ComputationGraphError::from)
        },
    }
}

//TODO Fix to handle errors and clean up code (unwraps remain on node lookups and tensor computation)
//Kahn's Algorithm
//Evaluates every ancestor of `target_key` (and the target itself) in topological order,
//populating each non-root node's cached tensor via `compute_tensor`.
pub fn populating_eval(&mut self, target_key: NodeKey) {
    //Nodes whose parents are all computed (initially: reachable roots), ready to evaluate.
    let mut open_nodes = Vec::<NodeKey>::new();

    //Maps a parent key to the child nodes that depend on it.
    let mut node_children = HashMap::<NodeKey, Vec<NodeKey>>::new();
    let mut visited_nodes = HashSet::<NodeKey>::new();
    let mut to_eval_nodes = vec![target_key];

    //Phase 1: DFS from the target to discover roots and record parent -> children edges.
    while let Some(node_key) = to_eval_nodes.pop() {
        let node = self.get_node(node_key).unwrap();

        //Edge is Copy, so dereference once and bind keys by value.
        match *node.edge() {
            Edge::Root => {
                open_nodes.push(node_key);
            },
            Edge::Abs(a_key, _) | Edge::Neg(a_key, _) => {
                node_children.entry(a_key).or_default().push(node_key);

                //Fix: mark the parent visited when it is queued so a shared parent is
                //not traversed (and its children recorded) more than once.
                if visited_nodes.insert(a_key) {
                    to_eval_nodes.push(a_key);
                }
            },
            Edge::Add(a_key, b_key, _)
            | Edge::Sub(a_key, b_key, _)
            | Edge::Mul(a_key, b_key, _)
            | Edge::Div(a_key, b_key, _) => {
                node_children.entry(a_key).or_default().push(node_key);

                //Avoid registering the child twice when both parents are the same node.
                if b_key != a_key {
                    node_children.entry(b_key).or_default().push(node_key);
                }

                if visited_nodes.insert(a_key) {
                    to_eval_nodes.push(a_key);
                }

                //Fix: the original inserted `a_key` here, so `b_key` was never marked
                //visited and could be re-queued, duplicating its children entries.
                if visited_nodes.insert(b_key) {
                    to_eval_nodes.push(b_key);
                }
            },
        }
    }

    //Phase 2: Kahn's algorithm — pop ready nodes, compute their tensors, and release
    //children once all of their parents have been processed.
    let mut processed_nodes = HashSet::<NodeKey>::from_iter(open_nodes.iter().copied());

    while let Some(node_key) = open_nodes.pop() {
        //Roots already hold a tensor; everything else is computed from its parents.
        if !self.get_node(node_key).unwrap().is_root() {
            let comp_tensor = self.compute_tensor(node_key).unwrap();
            self.get_node_mut(node_key).unwrap().set_tensor(comp_tensor);
        }

        processed_nodes.insert(node_key);

        if let Some(children_keys) = node_children.get(&node_key) {
            for child_key in children_keys {
                let child_node = self.get_node(*child_key).unwrap();

                match child_node.edge() {
                    Edge::Root => {
                        //It should be impossible for a root to be a child
                        unreachable!("a root node cannot be the child of another node")
                    },
                    Edge::Abs(_, _) | Edge::Neg(_, _) => {
                        //Since we know we just processed the parent there is no need to check
                        open_nodes.push(*child_key);
                    },
                    Edge::Add(a_key, b_key, _)
                    | Edge::Sub(a_key, b_key, _)
                    | Edge::Mul(a_key, b_key, _)
                    | Edge::Div(a_key, b_key, _) => {
                        if processed_nodes.contains(a_key) && processed_nodes.contains(b_key) {
                            open_nodes.push(*child_key);
                        }
                    },
                }
            }
        }
    }
}
}

pub fn abs<E: Engine<T>, F: EngineTensorFactory<T>>(&mut self, a: NodeKey) -> NodeKey {
Expand All @@ -177,3 +233,17 @@ impl<T: AllowedUnit> CompGraph<T> {
self.create_node(Edge::Div(a, b, E::div::<F>))
}
}

#[derive(Error, Debug)]
pub enum ComputationGraphError {
#[error("Node key does not exist in this computation graph")]
NodeKeyDoesNotExist(NodeKey),
#[error("Root node doesn't contain computed tensor")]
RootNodeNotComputed(NodeKey),
#[error("Parent node does not exist in this computation graph")]
ParentNodeDoesNotExist(NodeKey),
#[error("Parent node not computed when expected to be")]
ParentNodeNotComputed(NodeKey),
#[error("Error in computation: {0}")]
ComputationError(#[from]EngineError),
}
3 changes: 0 additions & 3 deletions src/engine/basic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
use std::marker::PhantomData;
use super::tensor::factory::EngineTensorFactory;

pub struct Basic {}
6 changes: 4 additions & 2 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ pub trait Engine<T: AllowedUnit> {

#[derive(Error, Debug)]
pub enum EngineError {
#[error("the tensor of size {0} does not match {1}")]
ShapeMismatch(Shape, Shape)
#[error("The tensor of size {0} does not match {1}")]
ShapeMismatch(Shape, Shape),
#[error("The operation is not supported on this data type")]
OperationUnsupportedForType(),
}

Loading

0 comments on commit 49cd746

Please sign in to comment.