Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
rileysu committed Dec 16, 2023
1 parent 808dfcc commit 97478fe
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 150 deletions.
7 changes: 6 additions & 1 deletion ideas.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@
## TODO

- Refactor comp_graph to improve errors (ones with no nodekey) and reduce repeated code
- Need a way to remove graphs / subgraphs once they are finished
- Maybe separate graphs from context and create a new graph per training / inference iteration
- Use phantom to make tensors references to graph so it can't outlive graph
- Probably remove the distinction between context and comp_graph
- Dump graph on calculation
- Allow for recalc maybe
54 changes: 27 additions & 27 deletions src/context/comp_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use super::edge::Edge;

#[derive(Debug)]
pub struct Node<T: AllowedUnit> {
tensor: Option<EngineTensor<T>>,
tensor: Option<Box<dyn EngineTensor<T>>>,
edge: Edge<T>,
}

impl<T: AllowedUnit> Node<T> {
pub fn create_root(tensor: EngineTensor<T>) -> Self {
pub fn create_root(tensor: Box<dyn EngineTensor<T>>) -> Self {
Self {
tensor: Some(tensor),
edge: Edge::Root,
Expand All @@ -28,11 +28,11 @@ impl<T: AllowedUnit> Node<T> {
}
}

pub fn tensor(&self) -> Option<&EngineTensor<T>> {
self.tensor.as_ref()
pub fn tensor(&self) -> Option<&dyn EngineTensor<T>> {
self.tensor.as_deref()
}

pub fn set_tensor(&mut self, tensor: EngineTensor<T>) {
pub fn set_tensor(&mut self, tensor: Box<dyn EngineTensor<T>>) {
self.tensor = Some(tensor)
}

Expand Down Expand Up @@ -87,7 +87,7 @@ impl<T: AllowedUnit> CompGraph<T> {
}

//Root is a node that is a starting point for computation
pub fn create_root(&mut self, tensor: EngineTensor<T>) -> NodeKey {
pub fn create_root(&mut self, tensor: Box<dyn EngineTensor<T>>) -> NodeKey {
self.nodes.insert(Node::create_root(tensor))
}

Expand Down Expand Up @@ -193,7 +193,7 @@ impl<T: AllowedUnit> CompGraph<T> {

//Cache for calculated nodes
//Should be cleared once no dependencies left
let mut comp_cache = HashMap::<NodeKey, EngineTensor<T>>::new();
let mut comp_cache = HashMap::<NodeKey, Box<dyn EngineTensor<T>>>::new();

while let Some(node_key) = open.pop() {
let node = self.get_node(&node_key).ok_or(ComputationGraphError::NodeDoesNotExist(target))?;
Expand All @@ -202,7 +202,7 @@ impl<T: AllowedUnit> CompGraph<T> {
let comp_tensor = node.edge().compute_tensor(
|k| {
match comp_cache.get(&k) {
Some(tensor) => Ok(tensor),
Some(tensor) => Ok(tensor.as_ref()),
None => Ok(self.get_node(&k).ok_or(ComputationGraphError::NodeDoesNotExist(k))?.tensor().ok_or(ComputationGraphError::NodeNotComputed(k))?),
}
}
Expand Down Expand Up @@ -309,41 +309,41 @@ pub enum ComputationGraphError {
mod test {
use num::traits::Pow;

use crate::{engine::{tensor::factory::Array, basic::Basic}, helper::Shape};
use crate::{engine::{tensor::Array, basic::Basic}, helper::Shape};

use super::*;

pub fn init_simple_graph() -> (NodeKey, NodeKey, NodeKey, EngineTensor<f32>, CompGraph<f32>) {
pub fn init_simple_graph() -> (NodeKey, NodeKey, NodeKey, Box<dyn EngineTensor<f32>>, CompGraph<f32>) {
let mut graph = CompGraph::<f32>::new();

let root1 = graph.create_root(Array::from_slice([0.0, 1.0, 2.0, 3.0].as_slice(), Shape::from([2, 2].as_slice())));
let root2 = graph.create_root(Array::from_slice([0.0, 1.0, 2.0, 3.0].as_slice(), Shape::from([2, 2].as_slice())));

let added = graph.add::<Basic, Array>(root1, root2);
let added = graph.add::<Basic, Array<f32>>(root1, root2);

let expected = Array::from_slice([0.0, 2.0, 4.0, 6.0].as_slice(), Shape::from([2, 2].as_slice()));

return (root1, root2, added, expected, graph);
}

pub fn init_complex_graph() -> (NodeKey, EngineTensor<f32>, EngineTensor<f32>, CompGraph<f32>) {
pub fn init_complex_graph() -> (NodeKey, Box<dyn EngineTensor<f32>>, Box<dyn EngineTensor<f32>>, CompGraph<f32>) {
let mut graph = CompGraph::<f32>::new();

let root1 = graph.create_root(Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].as_slice(), Shape::from([3, 3].as_slice())));
let root2 = graph.create_root(Array::from_slice([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0].as_slice(), Shape::from([3, 3].as_slice())));
let root3 = graph.create_root(Array::from_slice([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0].as_slice(), Shape::from([3, 3].as_slice())));
let root4 = graph.create_root(Array::from_slice([1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0].as_slice(), Shape::from([3, 3].as_slice())));

let op1 = graph.div::<Basic, Array>(root4, root1);
let op2 = graph.mul::<Basic, Array>(op1, root2);
let op3 = graph.sub::<Basic, Array>(op2, root3);
let op1 = graph.div::<Basic, Array<f32>>(root4, root1);
let op2 = graph.mul::<Basic, Array<f32>>(op1, root2);
let op3 = graph.sub::<Basic, Array<f32>>(op2, root3);

let op4 = graph.mul_scalar::<Basic, Array>(2.0, op3);
let op5 = graph.div_scalar_rh::<Basic, Array>(op4, 2.0);
let op4 = graph.mul_scalar::<Basic, Array<f32>>(2.0, op3);
let op5 = graph.div_scalar_rh::<Basic, Array<f32>>(op4, 2.0);

let op6 = graph.mul::<Basic, Array>(op5, op5);
let op6 = graph.mul::<Basic, Array<f32>>(op5, op5);

let op7 = graph.div::<Basic, Array>(op6, root1);
let op7 = graph.div::<Basic, Array<f32>>(op6, root1);

return (op7, Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].as_slice(), Shape::from([3, 3].as_slice())), Array::from_slice([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0].as_slice(), Shape::from([3, 3].as_slice())), graph);
}
Expand Down Expand Up @@ -371,7 +371,7 @@ mod test {

assert!(node.tensor().is_some());

assert_eq!(*node.tensor().unwrap(), expected);
assert_eq!(node.tensor().unwrap(), expected.as_ref());

node.clear_tensor().unwrap();

Expand All @@ -383,7 +383,7 @@ mod test {

assert!(node.tensor().is_some());

assert_eq!(*node.tensor().unwrap(), expected);
assert_eq!(node.tensor().unwrap(), expected.as_ref());
}

#[test]
Expand All @@ -394,22 +394,22 @@ mod test {

let mut out = node_key;
for _ in 0..2_usize.pow(power as u32) {
out = graph.div::<Basic, Array>(out, out);
out = graph.div::<Basic, Array<_>>(out, out);
}

graph.non_populating_eval(out).unwrap();

let node = graph.get_node_mut(&out).unwrap();

assert_eq!(*node.tensor().unwrap(), expected_unit);
assert_eq!(*node.tensor().unwrap(), *expected_unit);

node.clear_tensor().unwrap();

graph.populating_eval(out).unwrap();

let node = graph.get_node(&out).unwrap();

assert_eq!(*node.tensor().unwrap(), expected_unit);
assert_eq!(*node.tensor().unwrap(), *expected_unit);
}

#[test]
Expand All @@ -429,7 +429,7 @@ mod test {
let a_key = keys[0];
let b_key = keys[1];

new_node_keys.push(graph.add::<Basic, Array>(a_key, b_key));
new_node_keys.push(graph.add::<Basic, Array<_>>(a_key, b_key));
}
}

Expand All @@ -441,14 +441,14 @@ mod test {

let node = graph.get_node_mut(node_key).unwrap();

assert_eq!(*node.tensor().unwrap(), expected);
assert_eq!(*node.tensor().unwrap(), *expected);

node.clear_tensor().unwrap();

graph.populating_eval(*node_key).unwrap();

let node = graph.get_node(node_key).unwrap();

assert_eq!(*node.tensor().unwrap(), expected);
assert_eq!(*node.tensor().unwrap(), *expected);
}
}
30 changes: 15 additions & 15 deletions src/context/edge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@ use super::comp_graph::{NodeKey, ComputationGraphError};
pub enum Edge<T: AllowedUnit> {
Root,

Abs(NodeKey, fn(&EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
Neg(NodeKey, fn(&EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),

AddScalar(T, NodeKey, fn(T, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
SubScalarLH(T, NodeKey, fn(T, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
SubScalarRH(NodeKey, T, fn(&EngineTensor<T>, T) -> Result<EngineTensor<T>, EngineError>),
MulScalar(T, NodeKey, fn(T, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
DivScalarLH(T, NodeKey, fn(T, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
DivScalarRH(NodeKey, T, fn(&EngineTensor<T>, T) -> Result<EngineTensor<T>, EngineError>),

Add(NodeKey, NodeKey, fn(&EngineTensor<T>, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
Sub(NodeKey, NodeKey, fn(&EngineTensor<T>, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
Mul(NodeKey, NodeKey, fn(&EngineTensor<T>, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
Div(NodeKey, NodeKey, fn(&EngineTensor<T>, &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>),
Abs(NodeKey, fn(&dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
Neg(NodeKey, fn(&dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),

AddScalar(T, NodeKey, fn(T, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
SubScalarLH(T, NodeKey, fn(T, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
SubScalarRH(NodeKey, T, fn(&dyn EngineTensor<T>, T) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
MulScalar(T, NodeKey, fn(T, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
DivScalarLH(T, NodeKey, fn(T, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
DivScalarRH(NodeKey, T, fn(&dyn EngineTensor<T>, T) -> Result<Box<dyn EngineTensor<T>>, EngineError>),

Add(NodeKey, NodeKey, fn(&dyn EngineTensor<T>, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
Sub(NodeKey, NodeKey, fn(&dyn EngineTensor<T>, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
Mul(NodeKey, NodeKey, fn(&dyn EngineTensor<T>, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
Div(NodeKey, NodeKey, fn(&dyn EngineTensor<T>, &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>),
}

impl<T: AllowedUnit> Edge<T> {
Expand All @@ -35,7 +35,7 @@ impl<T: AllowedUnit> Edge<T> {
}

//Single layer computation otherwise should throw an error
pub fn compute_tensor<'a, F: Fn(NodeKey) -> Result<&'a EngineTensor<T>, ComputationGraphError>>(&'a self, resolve: F) -> Result<EngineTensor<T>, ComputationGraphError> {
pub fn compute_tensor<'a, F: Fn(NodeKey) -> Result<&'a dyn EngineTensor<T>, ComputationGraphError>>(&'a self, resolve: F) -> Result<Box<dyn EngineTensor<T>>, ComputationGraphError> {
match self {
Edge::Root => Err(ComputationGraphError::RootNodeNotComputed()),
Edge::Abs(a_key, op) |
Expand Down
24 changes: 12 additions & 12 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ use thiserror::Error;
//Factory defines the unit as well as output tensor type
pub trait Engine<T: AllowedUnit> {
//Pointwise Single
fn abs<E: EngineTensorFactory<T>>(a: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn neg<E: EngineTensorFactory<T>>(a: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn abs<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn neg<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;

//Pointwise Scalar
fn add_scalar<E: EngineTensorFactory<T>>(s: T, a: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn sub_scalar_lh<E: EngineTensorFactory<T>>(s: T, a: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn sub_scalar_rh<E: EngineTensorFactory<T>>(a: &EngineTensor<T>, s: T) -> Result<EngineTensor<T>, EngineError>;
fn mul_scalar<E: EngineTensorFactory<T>>(s: T, a: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn div_scalar_lh<E: EngineTensorFactory<T>>(s: T, a: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn div_scalar_rh<E: EngineTensorFactory<T>>(a: &EngineTensor<T>, s: T) -> Result<EngineTensor<T>, EngineError>;
fn add_scalar<E: EngineTensorFactory<T>>(s: T, a: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn sub_scalar_lh<E: EngineTensorFactory<T>>(s: T, a: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn sub_scalar_rh<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>, s: T) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn mul_scalar<E: EngineTensorFactory<T>>(s: T, a: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn div_scalar_lh<E: EngineTensorFactory<T>>(s: T, a: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn div_scalar_rh<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>, s: T) -> Result<Box<dyn EngineTensor<T>>, EngineError>;

//Pointwise Double
fn add<E: EngineTensorFactory<T>>(a: &EngineTensor<T>, b: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn sub<E: EngineTensorFactory<T>>(a: &EngineTensor<T>, b: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn mul<E: EngineTensorFactory<T>>(a: &EngineTensor<T>, b: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn div<E: EngineTensorFactory<T>>(a: &EngineTensor<T>, b: &EngineTensor<T>) -> Result<EngineTensor<T>, EngineError>;
fn add<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>, b: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn sub<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>, b: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn mul<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>, b: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
fn div<E: EngineTensorFactory<T>>(a: &dyn EngineTensor<T>, b: &dyn EngineTensor<T>) -> Result<Box<dyn EngineTensor<T>>, EngineError>;
}

#[derive(Error, Debug)]
Expand Down
6 changes: 4 additions & 2 deletions src/engine/tensor/allowed_unit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::fmt::Debug;

use num::Num;

pub trait AllowedUnit: Num + Sized + Copy {}
impl<T: Num + Sized + Copy> AllowedUnit for T {}
pub trait AllowedUnit: Num + Sized + Copy + Debug + 'static {}
impl<T: Num + Sized + Copy + Debug + 'static> AllowedUnit for T {}

pub trait AllowedArray: AllowedUnit {}
impl<T: AllowedUnit> AllowedArray for T {}
Expand Down
Loading

0 comments on commit 97478fe

Please sign in to comment.