Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch frontend svds functions #28829

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixed torch frontend test_blas_and_lapack_ops.test_torch_svd
fixed testing dtype range
making input symmetric positive definite matrix
perform the value comparison outside the test helper, since SVD factors are not unique across backends
  • Loading branch information
Daniel4078 authored Sep 28, 2024
commit 36d219066817fbad4970f248c105dba9aa3aab17
Original file line number Diff line number Diff line change
Expand Up @@ -848,37 +848,75 @@ def test_torch_qr(
@handle_frontend_test(
    fn_tree="torch.svd",
    dtype_and_x=helpers.dtype_and_values(
        # NOTE(review): "valid" includes integer/bool dtypes, but torch.svd
        # only accepts float/complex input — confirm the helpers filter
        # unsupported dtypes, otherwise restrict to get_dtypes("float").
        available_dtypes=helpers.get_dtypes("valid"),
        min_value=0,
        max_value=10,
        # Square (n, n) matrices with 2 <= n <= 5.
        shape=helpers.ints(min_value=2, max_value=5).map(lambda n: (n, n)),
    ),
    some=st.booleans(),
    compute_uv=st.booleans(),
)
def test_torch_svd(
    dtype_and_x,
    some,
    compute_uv,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    backend_fw,
):
    """Test the torch.svd frontend.

    SVD factors are not unique (columns of U/V may differ by sign or, for
    repeated singular values, by rotation), so the raw U/S/V outputs cannot
    be compared element-wise against the ground truth.  Instead the helper
    is run with ``test_values=False`` and the comparison is done here by
    reconstructing ``U @ diag(S) @ V.T`` from each backend's outputs.
    """
    input_dtype, x = dtype_and_x
    x = np.asarray(x[0], dtype=input_dtype[0])
    # Make the input symmetric positive definite so the decomposition is
    # well-conditioned; the small diagonal shift guards against singularity.
    x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
    ret, frontend_ret = helpers.test_frontend_function(
        input_dtypes=input_dtype,
        backend_to_test=backend_fw,
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        test_values=False,  # values are compared manually below
        input=x,
        some=some,
        compute_uv=compute_uv,
    )
    # Convert both result triples to plain NumPy arrays for comparison.
    if backend_fw == "torch":
        frontend_ret = [t.detach() for t in frontend_ret]
        ret = [t.detach() for t in ret]
    ret = [np.asarray(t) for t in ret]
    frontend_ret = [
        np.asarray(t.resolve_conj()).astype(input_dtype[0]) for t in frontend_ret
    ]
    u, s, v = ret
    frontend_u, frontend_s, frontend_v = frontend_ret
    if not compute_uv:
        # torch.svd fills U and V with zeros when compute_uv=False, so only
        # the singular values are meaningful — compare them directly.
        helpers.assert_all_close(
            ret_np=frontend_s,
            ret_from_gt_np=s,
            atol=1e-04,
            backend=backend_fw,
            ground_truth_backend=frontend,
        )
    elif not some:
        # Full SVD: compare the reconstructed matrices, which ARE unique.
        helpers.assert_all_close(
            ret_np=frontend_u @ np.diag(frontend_s) @ frontend_v.T,
            ret_from_gt_np=u @ np.diag(s) @ v.T,
            atol=1e-04,
            backend=backend_fw,
            ground_truth_backend=frontend,
        )
    else:
        # Reduced SVD: U may have extra columns; keep only the first
        # len(S) columns before reconstructing.
        helpers.assert_all_close(
            ret_np=frontend_u[..., : frontend_s.shape[0]]
            @ np.diag(frontend_s)
            @ frontend_v.T,
            ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ v.T,
            atol=1e-04,
            backend=backend_fw,
            ground_truth_backend=frontend,
        )


@handle_frontend_test(
Expand Down