Skip to content

Commit

Permalink
Expose return statement results on Expression objects
Browse files Browse the repository at this point in the history
  • Loading branch information
jciskey committed May 17, 2020
1 parent 912620e commit 558d6c4
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
40 changes: 40 additions & 0 deletions cython/cexprtk/_cexprtk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cimport cexprtk_unknown_symbol_resolver
cimport cexprtk_util

from cpython.ref cimport Py_INCREF
from cython.operator cimport dereference

from libcpp.cast cimport dynamic_cast
from libcpp.vector cimport vector
Expand Down Expand Up @@ -220,6 +221,45 @@ cdef class Expression:
raise exception[0], exception[1], exception[2]
return v

def results(self):
"""Gets the results of evaluating the expression.
:return: A list of the results from evaluating the expression, if any.
:rtype: list"""
cdef exprtk.type_store_type ts
cdef exprtk.type_store[double].scalar_view * sv
cdef exprtk.type_store[double].string_view * string_view
cdef exprtk.type_store[double].vector_view * vector_view
cdef double x
cdef exprtk.results_context_type resultscontext = self._cexpressionptr.results()
ret_list = []
for i in range(resultscontext.count()):
ts = resultscontext[i]
if ts.type == 1:
sv = new exprtk.type_store[double].scalar_view(ts)
ret_list.append(sv.v_)
del sv
elif ts.type == 2:
# Get the vector and append it here
# ret_list.append(f'Type: {ts.type}')
vector_view = new exprtk.type_store[double].vector_view(ts)
sub_list = []
for i in range(vector_view.size()):
x = dereference(vector_view)[i]
sub_list.append(x)
ret_list.append(sub_list)
del vector_view
elif ts.type == 3:
# Get string and append it here
string_view = new exprtk.type_store[double].string_view(ts)
ret_list.append(exprtk.to_str(dereference(string_view)).decode('ascii'))
del string_view
# pass
else:
# No idea what type it is, ignore it
continue
return ret_list

def __call__(self):
"""Equivalent to calling value() method."""
return self.value()
Expand Down
24 changes: 24 additions & 0 deletions cython/exprtk.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,33 @@ cdef extern from "exprtk.hpp" namespace "exprtk":
ivararg_function[T]* get_vararg_function(string& vararg_function_name)
variable_ptr get_variable(string& variable_name)

cdef cppclass type_store[T]:
type_store() except +
int size
int type
cppclass scalar_view:
scalar_view(type_store[T]& ts) except +
T& v_
cppclass type_view[ViewType]:
type_view(type_store[T]& ts) except +
int size()
ViewType& operator[](int& i)
ViewType* data_
ctypedef type_view[T] vector_view
ctypedef type_view[char] string_view

cdef string to_str(type_store[double].string_view& view)

cdef cppclass results_context[T]:
results_context() except +
int count()
type_store[T] operator[](int& index)

cdef cppclass expression[T]:
expression() except +
void register_symbol_table(symbol_table[T])
T value()
results_context[T] results()

cdef cppclass parser[T]:
parser() except +
Expand All @@ -62,3 +84,5 @@ cdef extern from "exprtk.hpp" namespace "exprtk":
ctypedef symbol_table[double] symbol_table_type
ctypedef expression[double] expression_type
ctypedef parser[double] parser_type
ctypedef results_context[double] results_context_type
ctypedef type_store[double] type_store_type
33 changes: 33 additions & 0 deletions tests/test_cexprtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,39 @@ def testSymbolTableProperty(self):
v = expression()
self.assertAlmostEqual(18.0, v)

def testReturnResults(self):
"""Test that basic return calls in expressions work."""
st = cexprtk.Symbol_Table({}, {})
exp = "var x[2] := {1, 2}; return [4, 'abc', x];"
expression = cexprtk.Expression(exp, st)
v = expression.value()

results_list = expression.results()

self.assertEqual(3, len(results_list))

scalar_val = results_list[0]
self.assertAlmostEqual(4.0, scalar_val)

string_val = results_list[1]
self.assertEqual('abc', string_val)

vector_val = results_list[2]
self.assertEqual(2, len(vector_val))
self.assertAlmostEqual(1.0, vector_val[0])
self.assertAlmostEqual(2.0, vector_val[1])

def testResultsEmptyWithNoReturn(self):
"""Test that an expression has no results
when it doesn't include a return statement"""
st = cexprtk.Symbol_Table({},{})
expression = cexprtk.Expression("2+2", st)
v = expression.value()
self.assertAlmostEqual(4.0, v)

results_list = expression.results()
self.assertEqual(0, len(results_list))


class Symbol_TableVariablesTestCase(unittest.TestCase):
"""Tests for cexprtk._Symbol_Table_Variables"""
Expand Down

0 comments on commit 558d6c4

Please sign in to comment.