Skip to content

Commit

Permalink
Merge pull request opencog#2860 from linas/fast-math
Browse files Browse the repository at this point in the history
Fast math - speed up computation of dot products.
  • Loading branch information
linas authored Oct 14, 2021
2 parents e5e588b + ba83931 commit c6e5446
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 19 deletions.
4 changes: 2 additions & 2 deletions opencog/matrix/cosine.scm
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@
(let* ((star-obj (add-pair-stars LLOBJ))
(supp-obj (add-support-compute star-obj GET-CNT))
(prod-obj (add-support-compute
(add-tuple-math star-obj * GET-CNT)))
(add-fast-math star-obj * GET-CNT)))
(min-obj (add-support-compute
(add-tuple-math star-obj min GET-CNT)))
(max-obj (add-support-compute
(add-tuple-math star-obj max GET-CNT)))
(either-obj (add-support-compute
(add-tuple-math star-obj either GET-CNT)))
(both-obj (add-support-compute
(add-tuple-math star-obj both GET-CNT)))
(add-fast-math star-obj both GET-CNT)))
)

; -------------
Expand Down
40 changes: 36 additions & 4 deletions opencog/matrix/direct-sum.scm
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,11 @@
; (and so we can dispatch based on the types) or, if not,
; then that the atoms are in left/right basis, (and can thus
; be dispatched based on membership.) If this is not the case,
; this will return undefined results. This should be good enough
; to create the needed wild-cards.
(define (make-pair L-ATOM R-ATOM)
(init-a-base)
; then this attempts to handle the special case of one or the
; other being a VariableNode, and tries to do the right thing.
; If this is not the case, it will return undefined results.
; (which may be crazy.)
(define (make-ordinary-pair L-ATOM R-ATOM)
(if distinct-type
(if disjoint-left
(if (symb-comp? L-ATOM (LLA 'left-type))
Expand All @@ -230,6 +231,37 @@
(LLA 'make-pair L-ATOM R-ATOM)
(LLB 'make-pair L-ATOM R-ATOM))))

; Build a pair whose left side is a variable. Since L-ATOM carries no
; usable type information, the dispatch is driven by R-ATOM alone.
;  * If the right sides are disjoint, R-ATOM is compared against the
;    right-type of LLA; a match sends the request to LLA, anything
;    else goes to LLB.
;  * Otherwise, if the pair type is a TypeChoice, both candidate
;    pairs are built and wrapped in a ChoiceLink.
;  * Failing both, an 'invalid exception is thrown.
(define (make-left-variable-pair L-ATOM R-ATOM)
  (cond
    (disjoint-right
      ; Exactly one of the two objects owns this R-ATOM.
      (let ((owner (if (symb-comp? R-ATOM (LLA 'right-type)) LLA LLB)))
        (owner 'make-pair L-ATOM R-ATOM)))
    ((equal? (cog-type (get-pair-type)) 'TypeChoice)
      ; Ambiguous: offer both alternatives.
      (ChoiceLink
        (LLA 'make-pair L-ATOM R-ATOM)
        (LLB 'make-pair L-ATOM R-ATOM)))
    (else
      (throw 'invalid 'direct-sum "Don't know how to make this pair"))))

; Build a pair whose right side is a variable. Since R-ATOM carries no
; usable type information, the dispatch is driven by L-ATOM alone.
;  * If the left sides are disjoint, L-ATOM is compared against the
;    left-type of LLA; a match sends the request to LLA, anything
;    else goes to LLB.
;  * Otherwise, if the pair type is a TypeChoice, both candidate
;    pairs are built and wrapped in a ChoiceLink.
;  * Failing both, an 'invalid exception is thrown.
(define (make-right-variable-pair L-ATOM R-ATOM)
  (cond
    (disjoint-left
      ; Exactly one of the two objects owns this L-ATOM.
      (let ((owner (if (symb-comp? L-ATOM (LLA 'left-type)) LLA LLB)))
        (owner 'make-pair L-ATOM R-ATOM)))
    ((equal? (cog-type (get-pair-type)) 'TypeChoice)
      ; Ambiguous: offer both alternatives.
      (ChoiceLink
        (LLA 'make-pair L-ATOM R-ATOM)
        (LLB 'make-pair L-ATOM R-ATOM)))
    (else
      (throw 'invalid 'direct-sum "Don't know how to make this pair"))))

; Entry point for pair construction. After running `init-a-base`,
; route the request according to which side (if any) is a
; VariableNode: a variable on the left is checked first, then a
; variable on the right; when neither side is a variable the
; ordinary type/membership dispatch is used.
(define (make-pair L-ATOM R-ATOM)
  (define (is-var? ATOM)
    (equal? (cog-type ATOM) 'VariableNode))

  (init-a-base)
  (if (is-var? L-ATOM)
    (make-left-variable-pair L-ATOM R-ATOM)
    (if (is-var? R-ATOM)
      (make-right-variable-pair L-ATOM R-ATOM)
      (make-ordinary-pair L-ATOM R-ATOM))))

; Given a pair, find the left element in it.
(define (get-pair-left PAIR)
(init-a-set)
Expand Down
111 changes: 110 additions & 1 deletion opencog/matrix/fold-api.scm
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@

(use-modules (srfi srfi-1))
(use-modules (ice-9 optargs)) ; for define*-public
(use-modules (opencog))
(use-modules (opencog) (opencog exec))

; ---------------------------------------------------------------------
;
Expand Down Expand Up @@ -259,5 +259,114 @@
(else (apply LLOBJ (cons message args))))
)))

; ---------------------------------------------------------------------
;
(define*-public (add-fast-math LLOBJ FUNC #:optional
                 (GET-CNT 'get-count))
"
  add-fast-math LLOBJ FUNC - Fast version of `add-tuple-math`

  See `add-tuple-math` for details. This is much faster, as it uses
  the pattern engine to find tuples. It is limited, though: it will
  only return intersections! This is sufficient for taking products
  but is not enough for sums.

  For example, given two columns [y,z], the 'left-stars method
  will return the set

     { [(x,y), (x,z)] | both (x,y) and (x,z) are present in
                        the atomspace. }

  The temporary variable and the query pattern that are built for
  each search are removed from the atomspace after the search runs.
"
  ; NOTE(review): `star-obj` is not referenced anywhere below; every
  ; method call goes straight to LLOBJ. It is kept because it is not
  ; visible here whether `add-pair-stars` is wanted for validation of
  ; LLOBJ -- confirm before removing.
  (let ((star-obj (add-pair-stars LLOBJ))
        (get-cnt (lambda (x) (LLOBJ GET-CNT x)))
       )
    ; ---------------
    ; The 'left-type and 'right-type methods may hand back either a
    ; scheme symbol or an atom. Wrap symbols in a TypeNode so that the
    ; result can be used in a TypedVariable declaration.
    (define (thunk-type TY) (if (symbol? TY) (TypeNode TY) TY))

    ; ---------------
    ; Given a list of columns, return a list of tuples of pairs, one
    ; tuple for each row that has a pair present in EVERY one of the
    ; given columns. Each tuple lists the pairs in the same order as
    ; COL-TUPLE.
    (define (left-star-intersect COL-TUPLE)
      (define row-var (uniquely-named-variable)) ; shared rows
      (define row-type (thunk-type (LLOBJ 'left-type)))
      (define term-list
        (map (lambda (COL) (LLOBJ 'make-pair row-var COL)) COL-TUPLE))

      (define qry
        (Meet
          (TypedVariable row-var row-type)
          (Present term-list)))

      (define rowset (cog-value->list (cog-execute! qry)))

      ; Clean up. The variable is unique to this call, so it, the
      ; pair patterns built on it, and the query itself would
      ; otherwise pile up in the atomspace, one set per call.
      ; Everything temporary contains `row-var`, so a recursive
      ; extract of the variable removes all of it. The atoms in
      ; `rowset` do not contain the variable and are not affected.
      (cog-extract-recursive! row-var)

      ; Convert what the pattern engine returned to
      ; a list of scheme lists.
      (map
        (lambda (ROW)
          (map (lambda (COL) (LLOBJ 'make-pair ROW COL)) COL-TUPLE))
        rowset)
    )

    ; ---------------
    ; Given a list of rows, return a list of tuples of pairs, one
    ; tuple for each column that has a pair present in EVERY one of
    ; the given rows. Each tuple lists the pairs in the same order as
    ; ROW-TUPLE.
    (define (right-star-intersect ROW-TUPLE)
      (define col-var (uniquely-named-variable)) ; shared cols
      (define col-type (thunk-type (LLOBJ 'right-type)))
      (define term-list
        (map (lambda (ROW) (LLOBJ 'make-pair ROW col-var)) ROW-TUPLE))

      ; XXX TODO -- It is MUCH faster to have the qry create the
      ; desired pairs for us. However, this is problematic for
      ; the direct-sum, where the top query term is a ChoiceLink,
      ; and there is no easy way to get which choice was made!
      ; So we pay a fairly hefty penalty running the map immediately
      ; below. It would be better to do it all in the pattern engine.
      ; Even better: run the FUNC on everything, and return that.
      ; Better still: run the FUNC on the query results, in the C++
      ; code, and not in scheme.
      (define qry
        (Meet
          (TypedVariable col-var col-type)
          (Present term-list)))

      (define colset (cog-value->list (cog-execute! qry)))

      ; Clean up the temporary variable, the pair patterns built on
      ; it, and the query. See `left-star-intersect` above for the
      ; reasoning.
      (cog-extract-recursive! col-var)

      ; Convert what the pattern engine returned to
      ; a list of scheme lists.
      (map
        (lambda (COL)
          (map (lambda (ROW) (LLOBJ 'make-pair ROW COL)) ROW-TUPLE))
        colset)
    )

    ; ---------------
    ; Given a TUPLE of pairs, return a single number.
    ; The FUNC is applied to reduce the counts on each pair
    ; in the tuple down to just one number. An empty-list entry
    ; in the tuple is counted as zero.
    (define (get-func-count TUPLE)
      (apply FUNC
        (map
          (lambda (pr) (if (null? pr) 0 (get-cnt pr)))
          TUPLE)))

    ; ---------------
    ; Return a pointer to each method that this class overloads.
    ; Anything else is passed on to the 'provides method of LLOBJ.
    (define (provides meth)
      (case meth
        ((left-stars)   left-star-intersect)
        ((right-stars)  right-star-intersect)
        ((get-count)    get-func-count)
        (else           (LLOBJ 'provides meth))))

    ; ---------------

    ; Methods on this class. Unhandled messages are forwarded,
    ; unchanged, to LLOBJ.
    (lambda (message . args)
      (case message
        ((left-stars)   (apply left-star-intersect args))
        ((right-stars)  (apply right-star-intersect args))
        ((get-count)    (apply get-func-count args))
        ((provides)     (apply provides args))
        (else           (apply LLOBJ (cons message args)))))
  )
)

; ---------------------------------------------------------------------
; ---------------------------------------------------------------------
2 changes: 1 addition & 1 deletion opencog/matrix/symmetric-mi.scm
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
(sup-obj (add-support-api star-obj))
(trans-obj (add-transpose-api star-obj))
(prod-obj (add-support-compute
(add-tuple-math star-obj * GET-CNT)))
(add-fast-math star-obj * GET-CNT)))

; Cache of the totals
(mtm-total #f)
Expand Down
63 changes: 63 additions & 0 deletions tests/matrix/VectorAPIUTest.cxxtest
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class VectorAPIUTest : public CxxTest::TestSuite

void test_basic(void);
void test_marginals(void);
void test_fast(void);
void test_cosines(void);
void xtest_dynamic(void);
void test_fold(void);
Expand All @@ -88,6 +89,9 @@ class VectorAPIUTest : public CxxTest::TestSuite
TSM_ASSERT_LESS_THAN(MSG, atof(X.c_str()), VAL+EPS); \
TSM_ASSERT_LESS_THAN(MSG, VAL-EPS, atof(X.c_str()));

// CHKCMP: compare two numeric results that are both held as strings.
// Y (a std::string) is converted with atof() and passed to CHKFLT as
// the expected value for X; MSG is the message reported on failure.
#define CHKCMP(MSG,X,Y) \
CHKFLT(MSG, X, atof(Y.c_str()))

/*
* This is called once before each test, for each test (!!)
*/
Expand Down Expand Up @@ -401,6 +405,65 @@ void VectorAPIUTest::test_marginals(void)

// ====================================================================

/*
 * Cross-check the dot products. For a selection of word pairs, the
 * product reported by the `cosi` object is compared against the count
 * reported by `prod-t` (built on add-tuple-math) and by `prod-f`
 * (built on add-fast-math). Both the right-product and the
 * left-product directions are exercised.
 */
void VectorAPIUTest::test_fast(void)
{
	logger().debug("BEGIN TEST: %s", __FUNCTION__);

	std::string rc = eval->eval("(load-from-path \"tests/matrix/basic-data.scm\")");
	printf("Load of data >>>%s", rc.c_str());
	CHKERR;

	// Scratch strings, written by the macros below:
	// num = reference product, nut = tuple-math, nuf = fast-math.
	std::string num, nut, nuf;

	// Right-product counterpart of LEFT_PROD, local to this test.
#define FAST_RIGHT_PROD(A,B)\
	num = eval->eval("(cosi 'right-product (Word \"" A "\") (Word \"" B "\"))");\
	nut = eval->eval("(prod-t 'right-count (list (Word \"" A "\") (Word \"" B "\")))");\
	nuf = eval->eval("(prod-f 'right-count (list (Word \"" A "\") (Word \"" B "\")))");\
	CHKCMP("tuple-product(" A "," B ")", num, nut);\
	CHKCMP("fast-product (" A "," B ")", num, nuf);

	FAST_RIGHT_PROD("table", "dog")
	FAST_RIGHT_PROD("chicken", "dog")
	FAST_RIGHT_PROD("table", "chicken")

#undef FAST_RIGHT_PROD

	// --------------
	// Same checks, for the left-products.
#define LEFT_PROD(A,B)\
	num = eval->eval("(cosi 'left-product (Word \"" A "\") (Word \"" B "\"))");\
	nut = eval->eval("(prod-t 'left-count (list (Word \"" A "\") (Word \"" B "\")))");\
	nuf = eval->eval("(prod-f 'left-count (list (Word \"" A "\") (Word \"" B "\")))");\
	CHKCMP("tuple-product(" A "," B ")", num, nut);\
	CHKCMP("fast-product (" A "," B ")", num, nuf);

	// Every unordered pair drawn from {wings, eyes, legs, snouts},
	// including each word paired with itself.
	LEFT_PROD("wings", "wings")
	LEFT_PROD("wings", "eyes")
	LEFT_PROD("wings", "legs")
	LEFT_PROD("wings", "snouts")

	LEFT_PROD("eyes", "eyes")
	LEFT_PROD("eyes", "legs")
	LEFT_PROD("eyes", "snouts")

	LEFT_PROD("legs", "legs")
	LEFT_PROD("legs", "snouts")

	LEFT_PROD("snouts", "snouts")

	logger().debug("END TEST: %s", __FUNCTION__);
}

// ====================================================================

void VectorAPIUTest::test_cosines(void)
{
logger().debug("BEGIN TEST: %s", __FUNCTION__);
Expand Down
22 changes: 11 additions & 11 deletions tests/matrix/basic-api.scm
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
(define (get-count PAIR)
(cog-value-ref (cog-value PAIR (Predicate "counter")) 2))

; Look up the pair (L-ATOM, R-ATOM) with `get-pair` and report its
; observed count. When the lookup comes back as the empty list,
; i.e. there is no such pair, the count is reported as zero.
(define (get-pair-count L-ATOM R-ATOM)
  (let ((stats-atom (get-pair L-ATOM R-ATOM)))
    (if (null? stats-atom)
      0
      (get-count stats-atom))))

; Return the atom holding the count, creating it if it does
; not yet exist. Returns the same structure as the 'item-pair
; method (the get-pair function, above).
Expand Down Expand Up @@ -88,16 +94,6 @@
(define (fetch-all-pairs)
(fetch-incoming-by-type (Predicate "foo") 'EvaluationLink))

; Tell the stars object what we provide.
(define (provides meth)
(case meth
((get-pair) get-pair)
((get-count) get-count)
((make-pair) make-pair)
((left-element) get-left-element)
((right-element) get-right-element)
(else #f)))

; Methods on the class. To call these, quote the method name.
; Example: (OBJ 'left-wildcard WORD) calls the
; get-left-wildcard function, passing WORD as the argument.
Expand All @@ -111,6 +107,7 @@
((left-type) get-left-type)
((right-type) get-right-type)
((pair-type) get-pair-type)
((pair-count) get-pair-count)
((get-pair) get-pair)
((get-count) get-count)
((make-pair) make-pair)
Expand All @@ -120,7 +117,7 @@
((right-wildcard) get-right-wildcard)
((wild-wild) get-wild-wild)
((fetch-pairs) fetch-all-pairs)
((provides) provides)
((provides) (lambda (symb) #f))
((filters?) (lambda () #f))
(else (error "Bad method call on low-level API")))
args)))
Expand Down Expand Up @@ -150,4 +147,7 @@

(define symc (add-symmetric-mi-compute bapi))

; Two product objects over the same `sapi` object, both using `*` as
; the function and both wrapped with `add-support-compute`:
; `prod-t` is built on `add-tuple-math`, `prod-f` on `add-fast-math`.
; Having both allows the two implementations to be compared.
(define prod-t (add-support-compute (add-tuple-math sapi *)))
(define prod-f (add-support-compute (add-fast-math sapi *)))

*unspecified*

0 comments on commit c6e5446

Please sign in to comment.