Skip to content

Commit

Permalink
Add frombuffer to Torch frontend (ivy-llc#14109)
Browse files Browse the repository at this point in the history
Co-authored-by: Vansh Gupta <[email protected]>
  • Loading branch information
adityagandhamal and V-G-spec authored Apr 20, 2023
1 parent c6f7de4 commit 2623066
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
12 changes: 12 additions & 0 deletions ivy/functional/frontends/torch/creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,15 @@ def asarray(
copy=None,
):
return ivy.asarray(obj, copy=copy, dtype=dtype, device=device)


@to_ivy_arrays_and_back
def frombuffer(
buffer,
*,
dtype,
count=-1,
offset=0,
requires_grad=False,
):
return ivy.frombuffer(buffer, dtype=dtype, count=count, offset=offset)
46 changes: 46 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ivy
from hypothesis import strategies as st, assume
import math
import numpy as np

# local
import ivy_tests.test_ivy.helpers as helpers
Expand Down Expand Up @@ -688,3 +689,48 @@ def test_torch_from_dlpack(
fn_tree=fn_tree,
on_device=on_device,
)


@st.composite
def _get_dtype_buffer_count_offset(draw):
dtype, value = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
)
)
value = np.array(value)
length = value.size
value = value.tobytes()

offset = draw(helpers.ints(min_value=0, max_value=length - 1))
count = draw(helpers.ints(min_value=-(2**30), max_value=length - offset))
if count == 0:
count = -1
offset = offset * np.dtype(dtype[0]).itemsize

return dtype, value, count, offset


@handle_frontend_test(
fn_tree="torch.frombuffer",
dtype_buffer_count_offset=_get_dtype_buffer_count_offset(),
)
def test_torch_frombuffer(
dtype_buffer_count_offset,
test_flags,
frontend,
fn_tree,
on_device,
):
input_dtype, buffer, count, offset = dtype_buffer_count_offset
helpers.test_frontend_function(
input_dtypes=input_dtype,
test_flags=test_flags,
on_device=on_device,
frontend=frontend,
fn_tree=fn_tree,
buffer=buffer,
dtype=input_dtype[0],
count=count,
offset=offset,
)

0 comments on commit 2623066

Please sign in to comment.