Skip to content

Commit

Permalink
Add versions of toTensor that take a type as their second argument (#…
Browse files Browse the repository at this point in the history
…652)

These are convenience functions which are equivalent to calling the regular `toTensor` followed by `asType`. This makes code that uses this relatively common idiom less verbose.
  • Loading branch information
AngelEzquerra authored May 23, 2024
1 parent 2c4f2cd commit 534ea98
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/arraymancer/laser/tensor/initialization.nim
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,24 @@ proc toTensor*[T](a: openArray[T]): auto =
let data = toSeq(flatIter(a))
result = toTensor(data, shape)

proc toTensor*[T; U](a: openArray[T], typ: typedesc[U]): Tensor[U] {.inline.} =
## Convert an openArray into a Tensor of type `typ`
##
## This is a convenience function which given an input `a` is equivalent to
## calling `a.toTensor().asType(typ)`. If `typ` is the same type of the
## elements of `a` then it is the same as `a.toTensor()` (i.e. there is no
## overhead).
##
## Inputs:
## - An array or a seq (can be nested)
## - The target type of the result Tensor
## Result:
## - A Tensor of the selected type and the same shape as the input
when T is U:
toTensor(a)
else:
toTensor(a).asType(typ)

proc toTensor*[T](a: SomeSet[T]): auto =
## Convert a HashSet or an OrderedSet into a Tensor
##
Expand All @@ -270,6 +288,24 @@ proc toTensor*[T](a: SomeSet[T]): auto =
let data = toSeq(a)
result = toTensor(data, shape)

proc toTensor*[T; U](a: SomeSet[T], typ: typedesc[U]): Tensor[U] {.inline.} =
## Convert a HashSet or an OrderedSet into a Tensor of type `typ`
##
## This is a convenience function which given an input `a` is equivalent to
## calling `a.toTensor().asType(typ)`. If `typ` is the same type of the
## elements of `a` then it is the same as `a.toTensor()` (i.e. there is no
## overhead).
##
## Inputs:
## - An HashSet or an OrderedSet
## - The target type of the result Tensor
## Result:
## - A Tensor of the selected type
when T is U:
toTensor(a)
else:
toTensor(a).asType(typ)

proc fromBuffer*[T](rawBuffer: ptr UncheckedArray[T], shape: varargs[int], layout: static OrderType): Tensor[T] =
## Creates a `Tensor[T]` from a raw buffer, cast as `ptr UncheckedArray[T]`. The
## size derived from the given shape must match the size of the buffer!
Expand Down
4 changes: 4 additions & 0 deletions tests/tensor/test_init.nim
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ proc main() =
else:
echo "Bound-checking is disabled or OpenMP is used. The incorrect seq shape test has been skipped."

# Call `toTensor` with a target type
let t5 = [1, -3, 4].toTensor(Complex64)
check t5 == [complex(1.0), complex(-3.0), complex(4.0)].toTensor

test "Check that Tensor shape is in row-by-column order":
let s = @[@[1,2,3],@[3,2,1]]
let t = s.toTensor()
Expand Down

0 comments on commit 534ea98

Please sign in to comment.