From f0cc9090559329519633a0a23d2fb03c4ce4edcc Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 10:58:34 -0400 Subject: [PATCH 01/26] Fix malformed comments that include directives to both mypy and pylint -- each needs a preceding comment character --- spatialmath/geom3d.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/spatialmath/geom3d.py b/spatialmath/geom3d.py index 8b191ebd..c941f9ee 100755 --- a/spatialmath/geom3d.py +++ b/spatialmath/geom3d.py @@ -714,7 +714,7 @@ def isintersecting( """ return not l1.isparallel(l2, tol=tol) and bool(abs(l1 * l2) < tol * _eps) - def __eq__(l1, l2: Line3) -> bool: # type: ignore pylint: disable=no-self-argument + def __eq__(l1, l2: Line3) -> bool: # type: ignore # pylint: disable=no-self-argument """ Test if two lines are equivalent @@ -733,7 +733,7 @@ def __eq__(l1, l2: Line3) -> bool: # type: ignore pylint: disable=no-self-argum """ return l1.isequal(l2) - def __ne__(l1, l2: Line3) -> bool: # type:ignore pylint: disable=no-self-argument + def __ne__(l1, l2: Line3) -> bool: # type:ignore # pylint: disable=no-self-argument """ Test if two lines are not equivalent @@ -752,7 +752,7 @@ def __ne__(l1, l2: Line3) -> bool: # type:ignore pylint: disable=no-self-argume """ return not l1.isequal(l2) - def __or__(l1, l2: Line3) -> bool: # type:ignore pylint: disable=no-self-argument + def __or__(l1, l2: Line3) -> bool: # type:ignore # pylint: disable=no-self-argument """ Overloaded ``|`` operator tests for parallelism @@ -771,7 +771,7 @@ def __or__(l1, l2: Line3) -> bool: # type:ignore pylint: disable=no-self-argume """ return l1.isparallel(l2) - def __xor__(l1, l2: Line3) -> bool: # type:ignore pylint: disable=no-self-argument + def __xor__(l1, l2: Line3) -> bool: # type:ignore # pylint: disable=no-self-argument """ Overloaded ``^`` operator tests for intersection @@ -989,7 +989,7 @@ def closest_to_point(self, x: ArrayLike3) -> Tuple[R3, float]: def commonperp( l1, l2: Line3 - ) -> Line3: # type:ignore pylint: disable=no-self-argument + ) -> Line3: # type:ignore # pylint: disable=no-self-argument """ Common perpendicular to two lines @@ -1019,7 +1019,7 @@ def commonperp( def __mul__( left, right: Line3 - ) -> float: # type:ignore pylint: disable=no-self-argument + ) -> float: # type:ignore # pylint: disable=no-self-argument r""" Reciprocal product @@ -1047,7 +1047,7 @@ def __mul__( def __rmul__( right, left: SE3 - ) -> Line3: # type:ignore pylint: disable=no-self-argument + ) -> Line3: # type:ignore # pylint: disable=no-self-argument """ Rigid-body transformation of 3D line From 8ab48e6cbd19449834acd600c11bb0d8f0348574 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 11:10:58 -0400 Subject: [PATCH 02/26] Add a dev dependency on a mypy version that fixes the "Positional-only parameters are only supported in Python 3.8 and greater" bug. --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 23168906..1d07e912 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,8 @@ dev = [ "pytest-timeout", "pytest-xvfb", "coverage", - "flake8" + "flake8", + "mypy>=0.981", ] docs = [ From cf176554a4b5ebfc9b695cb76e7e022462afce03 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 11:18:18 -0400 Subject: [PATCH 03/26] Fix python version info to the form mypy understands: direct references to sys.version_info, and explicit asserts in files to ignore. --- spatialmath/base/_types_311.py | 6 +++++- spatialmath/base/_types_39.py | 4 ++++ spatialmath/base/types.py | 7 ++----- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/spatialmath/base/_types_311.py b/spatialmath/base/_types_311.py index bd1d64b9..384d9e55 100644 --- a/spatialmath/base/_types_311.py +++ b/spatialmath/base/_types_311.py @@ -1,4 +1,8 @@ -# for Python >= 3.9 +# for Python >= 3.11 + +import sys + +assert sys.version_info >= (3, 11) from typing import ( overload, diff --git a/spatialmath/base/_types_39.py b/spatialmath/base/_types_39.py index 350210f5..30099099 100644 --- a/spatialmath/base/_types_39.py +++ b/spatialmath/base/_types_39.py @@ -1,5 +1,9 @@ # for Python >= 3.9 +import sys + +assert sys.version_info >= (3, 9) + from typing import ( overload, cast, diff --git a/spatialmath/base/types.py b/spatialmath/base/types.py index eb35e9d2..2167855a 100644 --- a/spatialmath/base/types.py +++ b/spatialmath/base/types.py @@ -1,11 +1,8 @@ import sys -_version = sys.version_info.minor - - -if _version >= 11: +if sys.version_info >= (3, 11): from spatialmath.base._types_311 import * -elif _version >= 9: +elif sys.version_info >= (3, 9): from spatialmath.base._types_39 import * else: from spatialmath.base._types_35 import * From 09186f139c283e78e9940b681b75a2c1f4bf44ae Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 11:28:37 -0400 Subject: [PATCH 04/26] Use typing.Literal to fix issues with signatures in argcheck that will never match. --- spatialmath/base/argcheck.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/spatialmath/base/argcheck.py b/spatialmath/base/argcheck.py index 40f94336..b0b4464b 100644 --- a/spatialmath/base/argcheck.py +++ b/spatialmath/base/argcheck.py @@ -14,6 +14,7 @@ import math import numpy as np from collections.abc import Iterable +from typing import Literal # from spatialmath.base import symbolic as sym # HACK from spatialmath.base.symbolic import issymbol, symtype @@ -281,7 +282,7 @@ def verifymatrix( def getvector( v: ArrayLike, dim: Optional[Union[int, None]] = None, - out: str = "array", + out: Literal["array"] = "array", dtype: DTypeLike = np.float64, ) -> NDArray: ... @@ -291,7 +292,7 @@ def getvector( def getvector( v: ArrayLike, dim: Optional[Union[int, None]] = None, - out: str = "list", + out: Literal["list"] = "list", dtype: DTypeLike = np.float64, ) -> List[float]: ... @@ -301,7 +302,7 @@ def getvector( def getvector( v: Tuple[float, ...], dim: Optional[Union[int, None]] = None, - out: str = "sequence", + out: Literal["sequence"] = "sequence", dtype: DTypeLike = np.float64, ) -> Tuple[float, ...]: ... From b4f9dee0179966120486772a1733ab95b288cbf9 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 13:09:11 -0400 Subject: [PATCH 05/26] Fixed most of the remaining overload-cannot-match errors with typing.Literal --- spatialmath/base/transforms2d.py | 9 +++++---- spatialmath/base/transforms3d.py | 21 +++++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/spatialmath/base/transforms2d.py b/spatialmath/base/transforms2d.py index 682ea0ca..c735006d 100644 --- a/spatialmath/base/transforms2d.py +++ b/spatialmath/base/transforms2d.py @@ -17,6 +17,7 @@ import sys import math import numpy as np +from typing import Literal try: import matplotlib.pyplot as plt @@ -443,7 +444,7 @@ def trinv2(T: SE2Array) -> SE2Array: @overload # pragma: no cover def trlog2( T: SO2Array, - twist: bool = False, + twist: Literal[False] = False, check: bool = True, tol: float = 20, ) -> so2Array: @@ -453,7 +454,7 @@ def trlog2( @overload # pragma: no cover def trlog2( T: SE2Array, - twist: bool = False, + twist: Literal[False] = False, check: bool = True, tol: float = 20, ) -> se2Array: @@ -463,7 +464,7 @@ def trlog2( @overload # pragma: no cover def trlog2( T: SO2Array, - twist: bool = True, + twist: Literal[True] = True, check: bool = True, tol: float = 20, ) -> float: @@ -473,7 +474,7 @@ def trlog2( @overload # pragma: no cover def trlog2( T: SE2Array, - twist: bool = True, + twist: Literal[True] = True, check: bool = True, tol: float = 20, ) -> R3: diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index 3617f965..60e92ba2 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -18,6 +18,7 @@ from collections.abc import Iterable import math import numpy as np +from typing import Literal from spatialmath.base.argcheck import getunit, getvector, isvector, isscalar, ismatrix from spatialmath.base.vectors import ( @@ -1271,25 +1272,25 @@ def tr2rpy( # ---------------------------------------------------------------------------------------# @overload # pragma: no cover def trlog( - T: SO3Array, check: bool = True, twist: bool = False, tol: float = 20 + T: SO3Array, check: bool = True, twist: Literal[False] = False, tol: float = 20 ) -> so3Array: ... @overload # pragma: no cover def trlog( - T: SE3Array, check: bool = True, twist: bool = False, tol: float = 20 + T: SE3Array, check: bool = True, twist: Literal[False] = False, tol: float = 20 ) -> se3Array: ... @overload # pragma: no cover -def trlog(T: SO3Array, check: bool = True, twist: bool = True, tol: float = 20) -> R3: +def trlog(T: SO3Array, check: bool = True, twist: Literal[True] = True, tol: float = 20) -> R3: ... @overload # pragma: no cover -def trlog(T: SE3Array, check: bool = True, twist: bool = True, tol: float = 20) -> R6: +def trlog(T: SE3Array, check: bool = True, twist: Literal[True] = True, tol: float = 20) -> R6: ... @@ -2222,7 +2223,7 @@ def angvelxform_dot(𝚪, 𝚪d, full=True, representation="rpy/xyz"): def rotvelxform( 𝚪: ArrayLike3, inverse: bool = False, - full: bool = False, + full: Literal[False] = False, representation="rpy/xyz", ) -> R3x3: ... @@ -2232,7 +2233,7 @@ def rotvelxform( def rotvelxform( 𝚪: SO3Array, inverse: bool = False, - full: bool = False, + full: Literal[False] = False, ) -> R3x3: ... @@ -2241,7 +2242,7 @@ def rotvelxform( def rotvelxform( 𝚪: ArrayLike3, inverse: bool = False, - full: bool = True, + full: Literal[True] = True, representation="rpy/xyz", ) -> R6x6: ... @@ -2251,7 +2252,7 @@ def rotvelxform( def rotvelxform( 𝚪: SO3Array, inverse: bool = False, - full: bool = True, + full: Literal[True] = True, ) -> R6x6: ... @@ -2464,14 +2465,14 @@ def rotvelxform( @overload # pragma: no cover def rotvelxform_inv_dot( - 𝚪: ArrayLike3, 𝚪d: ArrayLike3, full: bool = False, representation: str = "rpy/xyz" + 𝚪: ArrayLike3, 𝚪d: ArrayLike3, full: Literal[False] = False, representation: str = "rpy/xyz" ) -> R3x3: ... @overload # pragma: no cover def rotvelxform_inv_dot( - 𝚪: ArrayLike3, 𝚪d: ArrayLike3, full: bool = True, representation: str = "rpy/xyz" + 𝚪: ArrayLike3, 𝚪d: ArrayLike3, full: Literal[True] = True, representation: str = "rpy/xyz" ) -> R6x6: ... From a4cce4e3bdfd8ab2c138a7de5615d15214f23db5 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 13:11:58 -0400 Subject: [PATCH 06/26] Fixed overload-cannot-match: removed @overload for SO3.__init__ in favor of Union type (|) for arg --- spatialmath/pose3d.py | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index b4301d93..f7b357c5 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -53,31 +53,12 @@ class SO3(BasePoseMatrix): :parts: 1 """ - @overload - def __init__(self): - ... - - @overload - def __init__(self, arg: SO3, *, check=True): - ... - - @overload - def __init__(self, arg: SE3, *, check=True): - ... - - @overload - def __init__(self, arg: SO3Array, *, check=True): - ... - - @overload - def __init__(self, arg: List[SO3Array], *, check=True): - ... - - @overload - def __init__(self, arg: List[Union[SO3, SO3Array]], *, check=True): - ... - - def __init__(self, arg=None, *, check=True): + def __init__( + self, + arg : SO3 | SE3 | SO3Array | List[SO3 | SO3Array] = None, + *, + check=True, + ): """ Construct new SO(3) object From a1fe3c2dd5791a9cd8ff0eda5e1837d05ac903b1 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 13:14:24 -0400 Subject: [PATCH 07/26] Fixed overload-cannot-match: removed redundant @overload declarations for qisequal --- spatialmath/base/quaternions.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/spatialmath/base/quaternions.py b/spatialmath/base/quaternions.py index 8f33bc1c..46bba95b 100755 --- a/spatialmath/base/quaternions.py +++ b/spatialmath/base/quaternions.py @@ -173,27 +173,12 @@ def qisunit(q: ArrayLike4, tol: float = 20) -> bool: return smb.iszerovec(q, tol=tol) -@overload def qisequal( - q1: ArrayLike4, - q2: ArrayLike4, - tol: float = 20, - unitq: Optional[bool] = False, -) -> bool: - ... - - -@overload -def qisequal( - q1: ArrayLike4, - q2: ArrayLike4, - tol: float = 20, - unitq: Optional[bool] = True, -) -> bool: - ... - - -def qisequal(q1, q2, tol: float = 20, unitq: Optional[bool] = False): + q1: ArrayLike4, + q2: ArrayLike4, + tol: float = 20, + unitq: Optional[bool] = False, +): """ Test if quaternions are equal From f246c89aa7484e8d92021d02f5fd0714ec19a9d7 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 28 Oct 2024 13:27:14 -0400 Subject: [PATCH 08/26] Fixed all import-untyped errors (for missing types/stubs) by adding comment # type: ignore --- spatialmath/base/animate.py | 2 +- spatialmath/base/graphics.py | 14 +++++++------- spatialmath/base/quaternions.py | 2 +- spatialmath/base/symbolic.py | 2 +- spatialmath/base/transforms2d.py | 8 ++++---- spatialmath/base/transforms3d.py | 4 ++-- spatialmath/base/transformsNd.py | 2 +- spatialmath/base/vectors.py | 2 +- spatialmath/baseposematrix.py | 4 ++-- spatialmath/geom2d.py | 8 ++++---- spatialmath/geom3d.py | 2 +- spatialmath/spline.py | 6 +++--- spatialmath/timing.py | 2 +- 13 files changed, 29 insertions(+), 29 deletions(-) diff --git a/spatialmath/base/animate.py b/spatialmath/base/animate.py index a2e31f72..1f8b6c47 100755 --- a/spatialmath/base/animate.py +++ b/spatialmath/base/animate.py @@ -11,7 +11,7 @@ from __future__ import annotations import os.path import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt # type: ignore from matplotlib import animation import spatialmath.base as smb from collections.abc import Iterable, Iterator diff --git a/spatialmath/base/graphics.py b/spatialmath/base/graphics.py index 2ce18dc8..fa103496 100644 --- a/spatialmath/base/graphics.py +++ b/spatialmath/base/graphics.py @@ -2,7 +2,7 @@ from itertools import product import warnings import numpy as np -from matplotlib import colors +from matplotlib import colors # type: ignore from spatialmath import base as smb from spatialmath.base.types import * @@ -23,14 +23,14 @@ """ try: - import matplotlib.pyplot as plt - from matplotlib.patches import Circle - from mpl_toolkits.mplot3d.art3d import ( + import matplotlib.pyplot as plt # type: ignore + from matplotlib.patches import Circle # type: ignore + from mpl_toolkits.mplot3d.art3d import ( # type: ignore Poly3DCollection, Line3DCollection, pathpatch_2d_to_3d, ) - from mpl_toolkits.mplot3d import Axes3D + from mpl_toolkits.mplot3d import Axes3D # type: ignore # TODO # return a redrawer object, that can be used for animation @@ -796,13 +796,13 @@ def ellipse( so to avoid inverting ``E`` twice to compute the ellipse, we flag that the inverse is provided using ``inverted``. """ - from scipy.linalg import sqrtm + from scipy.linalg import sqrtm # type: ignore if E.shape != (2, 2): raise ValueError("ellipse is defined by a 2x2 matrix") if confidence: - from scipy.stats.distributions import chi2 + from scipy.stats.distributions import chi2 # type: ignore # process the probability s = math.sqrt(chi2.ppf(confidence, df=2)) * scale diff --git a/spatialmath/base/quaternions.py b/spatialmath/base/quaternions.py index 46bba95b..940a64c8 100755 --- a/spatialmath/base/quaternions.py +++ b/spatialmath/base/quaternions.py @@ -16,7 +16,7 @@ import spatialmath.base as smb from spatialmath.base.argcheck import getunit from spatialmath.base.types import * -import scipy.interpolate as interpolate +import scipy.interpolate as interpolate # type: ignore from typing import Optional from functools import lru_cache diff --git a/spatialmath/base/symbolic.py b/spatialmath/base/symbolic.py index 2d92f4d4..34a961ad 100644 --- a/spatialmath/base/symbolic.py +++ b/spatialmath/base/symbolic.py @@ -16,7 +16,7 @@ try: # pragma: no cover # print('Using SymPy') - import sympy + import sympy # type: ignore _symbolics = True symtype = (sympy.Expr,) diff --git a/spatialmath/base/transforms2d.py b/spatialmath/base/transforms2d.py index c735006d..7d83371d 100644 --- a/spatialmath/base/transforms2d.py +++ b/spatialmath/base/transforms2d.py @@ -20,7 +20,7 @@ from typing import Literal try: - import matplotlib.pyplot as plt + import matplotlib.pyplot as plt # type: ignore _matplotlib_exists = True except ImportError: @@ -35,7 +35,7 @@ try: # pragma: no cover # print('Using SymPy') - import sympy + import sympy # type: ignore _symbolics = True @@ -1132,7 +1132,7 @@ def ICP2d( # hack below to use points2tr above # use ClayFlannigan's improved data association - from scipy.spatial import KDTree + from scipy.spatial import KDTree # type: ignore def _FindCorrespondences( tree, source, reference @@ -1227,7 +1227,7 @@ def _FindCorrespondences( import matplotlib.pyplot as plt # from mpl_toolkits.axisartist import Axes - from matplotlib.axes import Axes + from matplotlib.axes import Axes # type: ignore def trplot2( T: Union[SO2Array, SE2Array], diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index 60e92ba2..5437c63c 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -2909,8 +2909,8 @@ def _vec2s(fmt, v): try: - import matplotlib.pyplot as plt - from mpl_toolkits.mplot3d import Axes3D + import matplotlib.pyplot as plt # type: ignore + from mpl_toolkits.mplot3d import Axes3D # type: ignore _matplotlib_exists = True except ImportError: diff --git a/spatialmath/base/transformsNd.py b/spatialmath/base/transformsNd.py index c04e5d8b..78d5841f 100644 --- a/spatialmath/base/transformsNd.py +++ b/spatialmath/base/transformsNd.py @@ -24,7 +24,7 @@ try: # pragma: no cover # print('Using SymPy') - from sympy import Matrix + from sympy import Matrix # type: ignore _symbolics = True diff --git a/spatialmath/base/vectors.py b/spatialmath/base/vectors.py index f29740a3..dffca21b 100644 --- a/spatialmath/base/vectors.py +++ b/spatialmath/base/vectors.py @@ -18,7 +18,7 @@ try: # pragma: no cover # print('Using SymPy') - import sympy + import sympy # type: ignore _symbolics = True diff --git a/spatialmath/baseposematrix.py b/spatialmath/baseposematrix.py index 1a850600..54e8b1e3 100644 --- a/spatialmath/baseposematrix.py +++ b/spatialmath/baseposematrix.py @@ -23,7 +23,7 @@ # colored printing of matrices to the terminal # colored package has much finer control than colorama, but the latter is available by default with anaconda try: - from colored import fg, bg, attr + from colored import fg, bg, attr # type: ignore _colored = True # print('using colored output') @@ -35,7 +35,7 @@ _colored = False try: - from ansitable import ANSIMatrix + from ansitable import ANSIMatrix # type: ignore _ANSIMatrix = True # print('using colored output') diff --git a/spatialmath/geom2d.py b/spatialmath/geom2d.py index 55eccb2a..b38a0115 100755 --- a/spatialmath/geom2d.py +++ b/spatialmath/geom2d.py @@ -9,10 +9,10 @@ from functools import reduce import warnings -import matplotlib.pyplot as plt -from matplotlib.path import Path -from matplotlib.patches import PathPatch -from matplotlib.transforms import Affine2D +import matplotlib.pyplot as plt # type: ignore +from matplotlib.path import Path # type: ignore +from matplotlib.patches import PathPatch # type: ignore +from matplotlib.transforms import Affine2D # type: ignore import numpy as np from spatialmath import SE2 diff --git a/spatialmath/geom3d.py b/spatialmath/geom3d.py index c941f9ee..a109861c 100755 --- a/spatialmath/geom3d.py +++ b/spatialmath/geom3d.py @@ -6,7 +6,7 @@ import numpy as np import math from collections import namedtuple -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt # type: ignore import spatialmath.base as base from spatialmath.base.types import * from spatialmath.baseposelist import BasePoseList diff --git a/spatialmath/spline.py b/spatialmath/spline.py index 0a472ecc..428a8143 100644 --- a/spatialmath/spline.py +++ b/spatialmath/spline.py @@ -9,10 +9,10 @@ from functools import cached_property from typing import List, Optional, Tuple, Set -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt # type: ignore import numpy as np -from scipy.interpolate import BSpline, CubicSpline -from scipy.spatial.transform import Rotation, RotationSpline +from scipy.interpolate import BSpline, CubicSpline # type: ignore +from scipy.spatial.transform import Rotation, RotationSpline # type: ignore from spatialmath import SE3, SO3, Twist3 from spatialmath.base.transforms3d import tranimate diff --git a/spatialmath/timing.py b/spatialmath/timing.py index ae169909..82e96c9d 100755 --- a/spatialmath/timing.py +++ b/spatialmath/timing.py @@ -8,7 +8,7 @@ import timeit -from ansitable import ANSITable, Column +from ansitable import ANSITable, Column # type: ignore N = 100000 From f1c97305b587475f97a2d8290da986026200e2bb Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 29 Oct 2024 10:45:39 -0400 Subject: [PATCH 09/26] PEP 484 prohibits implicit Optional; make the type explicit when the default is None. --- spatialmath/DualQuaternion.py | 4 +++- spatialmath/base/graphics.py | 6 +++--- spatialmath/baseposematrix.py | 4 ++-- spatialmath/pose3d.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/spatialmath/DualQuaternion.py b/spatialmath/DualQuaternion.py index 3b945d7c..3c8bcd1b 100644 --- a/spatialmath/DualQuaternion.py +++ b/spatialmath/DualQuaternion.py @@ -1,5 +1,7 @@ from __future__ import annotations import numpy as np +from typing import Optional + from spatialmath import Quaternion, UnitQuaternion, SE3 from spatialmath import base from spatialmath.base.types import * @@ -30,7 +32,7 @@ class DualQuaternion: :seealso: :func:`UnitDualQuaternion` """ - def __init__(self, real: Quaternion = None, dual: Quaternion = None): + def __init__(self, real: Optional[Quaternion] = None, dual: Optional[Quaternion] = None): """ Construct a new dual quaternion diff --git a/spatialmath/base/graphics.py b/spatialmath/base/graphics.py index fa103496..f04d2cd5 100644 --- a/spatialmath/base/graphics.py +++ b/spatialmath/base/graphics.py @@ -1570,7 +1570,7 @@ def axes_logic( return ax def plotvol2( - dim: ArrayLike = None, + dim: Optional[ArrayLike] = None, ax: Optional[plt.Axes] = None, equal: Optional[bool] = True, grid: Optional[bool] = False, @@ -1625,7 +1625,7 @@ def plotvol2( return ax def plotvol3( - dim: ArrayLike = None, + dim: Optional[ArrayLike] = None, ax: Optional[plt.Axes] = None, equal: Optional[bool] = True, grid: Optional[bool] = False, @@ -1685,7 +1685,7 @@ def plotvol3( ax._plotvol = True return ax - def expand_dims(dim: ArrayLike = None, nd: int = 2) -> NDArray: + def expand_dims(dim: Optional[ArrayLike] = None, nd: int = 2) -> NDArray: """ Expand compact axis dimensions diff --git a/spatialmath/baseposematrix.py b/spatialmath/baseposematrix.py index 54e8b1e3..66116504 100644 --- a/spatialmath/baseposematrix.py +++ b/spatialmath/baseposematrix.py @@ -377,7 +377,7 @@ def log(self, twist: Optional[bool] = False) -> Union[NDArray, List[NDArray]]: else: return log - def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shortest: bool = True) -> Self: + def interp(self, end: Optional[bool] = None, s: Union[int, float, None] = None, shortest: bool = True) -> Self: """ Interpolate between poses (superclass method) @@ -443,7 +443,7 @@ def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shorte [smb.trinterp(start=self.A, end=end, s=_s, shortest=shortest) for _s in s] ) - def interp1(self, s: float = None) -> Self: + def interp1(self, s: Optional[float] = None) -> Self: """ Interpolate pose (superclass method) diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index f7b357c5..3e8e62cb 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -55,7 +55,7 @@ class SO3(BasePoseMatrix): def __init__( self, - arg : SO3 | SE3 | SO3Array | List[SO3 | SO3Array] = None, + arg : None | SO3 | SE3 | SO3Array | List[SO3 | SO3Array] = None, *, check=True, ): From 15dcae41c99636d4d262545e745e547a0349301a Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 29 Oct 2024 12:05:40 -0400 Subject: [PATCH 10/26] Remove the default for real from DualQuaternion, and do not support None. Note that this changes the interface, but no examples call DualQuaternion() without the real part. --- spatialmath/DualQuaternion.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/spatialmath/DualQuaternion.py b/spatialmath/DualQuaternion.py index 3c8bcd1b..95801c27 100644 --- a/spatialmath/DualQuaternion.py +++ b/spatialmath/DualQuaternion.py @@ -32,7 +32,7 @@ class DualQuaternion: :seealso: :func:`UnitDualQuaternion` """ - def __init__(self, real: Optional[Quaternion] = None, dual: Optional[Quaternion] = None): + def __init__(self, real: Quaternion, dual: Optional[Quaternion] = None): """ Construct a new dual quaternion @@ -57,14 +57,10 @@ def __init__(self, real: Optional[Quaternion] = None, dual: Optional[Quaternion] """ - if real is None and dual is None: - self.real = None - self.dual = None - return - elif dual is None and base.isvector(real, 8): + if dual is None and base.isvector(real, 8): self.real = Quaternion(real[0:4]) self.dual = Quaternion(real[4:8]) - elif real is not None and dual is not None: + elif dual is not None: if not isinstance(real, Quaternion): raise ValueError("real part must be a Quaternion subclass") if not isinstance(dual, Quaternion): @@ -72,7 +68,7 @@ def __init__(self, real: Optional[Quaternion] = None, dual: Optional[Quaternion] self.real = real # quaternion, real part self.dual = dual # quaternion, dual part else: - raise ValueError("expecting zero or two parameters") + raise ValueError("expecting one or two parameters") @classmethod def Pure(cls, x: ArrayLike3) -> Self: From d4dec05f1e2fa7c4d5e04b50f62413bdd4e841d8 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 29 Oct 2024 12:21:56 -0400 Subject: [PATCH 11/26] Use DualQuaternion, not Self, as a return type in various methods of DualQuaternion, since they invoke the constructor directly (instead of, say, self.__class__) --- spatialmath/DualQuaternion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spatialmath/DualQuaternion.py b/spatialmath/DualQuaternion.py index 95801c27..d6bf0189 100644 --- a/spatialmath/DualQuaternion.py +++ b/spatialmath/DualQuaternion.py @@ -117,7 +117,7 @@ def norm(self) -> Tuple[float, float]: b = self.real * self.dual.conj() + self.dual * self.real.conj() return (base.sqrt(a.s), base.sqrt(b.s)) - def conj(self) -> Self: + def conj(self) -> DualQuaternion: r""" Conjugate of dual quaternion @@ -140,7 +140,7 @@ def conj(self) -> Self: def __add__( left, right: DualQuaternion - ) -> Self: # pylint: disable=no-self-argument + ) -> DualQuaternion: # pylint: disable=no-self-argument """ Sum of two dual quaternions @@ -159,7 +159,7 @@ def __add__( def __sub__( left, right: DualQuaternion - ) -> Self: # pylint: disable=no-self-argument + ) -> DualQuaternion: # pylint: disable=no-self-argument """ Difference of two dual quaternions @@ -176,7 +176,7 @@ def __sub__( """ return DualQuaternion(left.real - right.real, left.dual - right.dual) - def __mul__(left, right: Self) -> Self: # pylint: disable=no-self-argument + def __mul__(left, right: Self) -> DualQuaternion: # pylint: disable=no-self-argument """ Product of dual quaternion From a83f20d5977ad6e8d92f14c6c8ff92bedd9bfd0d Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 29 Oct 2024 12:26:26 -0400 Subject: [PATCH 12/26] Also remove the possibility of real=None for UnitDualQuaternion --- spatialmath/DualQuaternion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spatialmath/DualQuaternion.py b/spatialmath/DualQuaternion.py index d6bf0189..1ea7862f 100644 --- a/spatialmath/DualQuaternion.py +++ b/spatialmath/DualQuaternion.py @@ -269,13 +269,14 @@ class UnitDualQuaternion(DualQuaternion): """ @overload - def __init__(self, T: SE3): + def __init__(self, real: SE3, dual: None = None): ... + @overload def __init__(self, real: Quaternion, dual: Quaternion): ... - def __init__(self, real=None, dual=None): + def __init__(self, real, dual=None): r""" Create new unit dual quaternion From a2630c8f332948b4def48bb8f7939c0a4764da0b Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 29 Oct 2024 16:46:20 -0400 Subject: [PATCH 13/26] in binary methods of BasePoseList: - don't narrow the input type - don't change the return type - raise, don't return, NotImplementedError --- spatialmath/baseposelist.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/spatialmath/baseposelist.py b/spatialmath/baseposelist.py index d729b902..211aedb0 100644 --- a/spatialmath/baseposelist.py +++ b/spatialmath/baseposelist.py @@ -362,17 +362,17 @@ def __setitem__(self, i: int, value: BasePoseList) -> None: self.data[i] = value.A # flag these binary operators as being not supported - def __lt__(self, other: BasePoseList) -> Type[Exception]: - return NotImplementedError + def __lt__(self, other): + raise NotImplementedError() - def __le__(self, other: BasePoseList) -> Type[Exception]: - return NotImplementedError + def __le__(self, other): + raise NotImplementedError() - def __gt__(self, other: BasePoseList) -> Type[Exception]: - return NotImplementedError + def __gt__(self, other): + raise NotImplementedError() - def __ge__(self, other: BasePoseList) -> Type[Exception]: - return NotImplementedError + def __ge__(self, other): + raise NotImplementedError() def append(self, item: BasePoseList) -> None: """ @@ -674,4 +674,4 @@ def unop( print(R.eulervec()) R = SO2([0.3, 0.4, 0.5]) - pass \ No newline at end of file + pass From 19af4ac165486b89a5a6d9559fc396cddd5edd5f Mon Sep 17 00:00:00 2001 From: John Barnett Date: Wed, 27 Nov 2024 15:29:59 -0500 Subject: [PATCH 14/26] abstractstaticmethod was deprecated in python 3.3, and mypy doesn't like it; use staticmethod and abstractmethod instead. --- spatialmath/baseposelist.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/spatialmath/baseposelist.py b/spatialmath/baseposelist.py index 211aedb0..de1b4c52 100644 --- a/spatialmath/baseposelist.py +++ b/spatialmath/baseposelist.py @@ -5,7 +5,7 @@ # pylint: disable=invalid-name from __future__ import annotations from collections import UserList -from abc import ABC, abstractproperty, abstractstaticmethod +from abc import ABC, abstractproperty, abstractmethod import copy import numpy as np from spatialmath.base.argcheck import isnumberlist, isscalar @@ -70,11 +70,12 @@ def shape(self): pass @staticmethod - @abstractstaticmethod + @abstractmethod def isvalid(x, check=True): pass - @abstractstaticmethod + @staticmethod + @abstractmethod def _identity(): pass From 0d776196ce273fb5e39de3d61cde034fa2763756 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 2 Dec 2024 13:13:29 -0500 Subject: [PATCH 15/26] Declare control_poses as a list in the supertype --- spatialmath/spline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spatialmath/spline.py b/spatialmath/spline.py index 428a8143..c19ec2b2 100644 --- a/spatialmath/spline.py +++ b/spatialmath/spline.py @@ -20,7 +20,7 @@ class SplineSE3(ABC): def __init__(self) -> None: - self.control_poses: SE3 + self.control_poses: List[SE3] @abstractmethod def __call__(self, t: float) -> SE3: From 32ee31e7a105480f9c824ca18077dd3ae7d2d470 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Mon, 2 Dec 2024 14:13:07 -0500 Subject: [PATCH 16/26] * args are lists --- spatialmath/base/graphics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spatialmath/base/graphics.py b/spatialmath/base/graphics.py index f04d2cd5..87c83595 100644 --- a/spatialmath/base/graphics.py +++ b/spatialmath/base/graphics.py @@ -327,7 +327,7 @@ def plot_homline( return handles def plot_box( - *fmt: Optional[str], + *fmt: List[str], lbrt: Optional[ArrayLike4] = None, lrbt: Optional[ArrayLike4] = None, lbwh: Optional[ArrayLike4] = None, From 7f1c10778abffa8ae873a6631beec342d714bd1c Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 3 Dec 2024 10:56:57 -0500 Subject: [PATCH 17/26] Resolve some override type errors --- spatialmath/baseposelist.py | 2 +- spatialmath/baseposematrix.py | 2 +- spatialmath/pose3d.py | 8 ++++++++ spatialmath/quaternion.py | 26 +++++++++++++------------- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/spatialmath/baseposelist.py b/spatialmath/baseposelist.py index de1b4c52..509443c3 100644 --- a/spatialmath/baseposelist.py +++ b/spatialmath/baseposelist.py @@ -71,7 +71,7 @@ def shape(self): @staticmethod @abstractmethod - def isvalid(x, check=True): + def isvalid(x, check: bool=True) -> bool: pass @staticmethod diff --git a/spatialmath/baseposematrix.py b/spatialmath/baseposematrix.py index 66116504..60ed6d09 100644 --- a/spatialmath/baseposematrix.py +++ b/spatialmath/baseposematrix.py @@ -1572,7 +1572,7 @@ def __isub__(left, right: Self): # pylint: disable=no-self-argument """ return left.__sub__(right) - def __eq__(left, right: Self) -> bool: # pylint: disable=no-self-argument + def __eq__(left, right): # pylint: disable=no-self-argument """ Overloaded ``==`` operator (superclass method) diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index 3e8e62cb..cf33230b 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -464,6 +464,10 @@ def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str """ return cls([smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)) for _ in range(0, N)], check=False) + @overload + def Eul(cls, phi: float, theta: float, psi: float, unit: str = "rad") -> SE3: + ... + @overload @classmethod def Eul(cls, *angles: float, unit: str = "rad") -> Self: @@ -514,6 +518,10 @@ def Eul(cls, *angles, unit: str = "rad") -> Self: else: return cls([smb.eul2r(a, unit=unit) for a in angles], check=False) + @overload + def RPY(cls, roll: float, pitch: float, yaw: float, unit: str = "rad") -> SE3: + ... + @overload @classmethod def RPY( diff --git a/spatialmath/quaternion.py b/spatialmath/quaternion.py index 51561036..0b458dab 100644 --- a/spatialmath/quaternion.py +++ b/spatialmath/quaternion.py @@ -131,7 +131,7 @@ def shape(self) -> Tuple[int]: return (4,) @staticmethod - def isvalid(x: ArrayLike4) -> bool: + def isvalid(x, check: bool=True): """ Test if vector is valid quaternion @@ -473,7 +473,7 @@ def inner(self, other) -> float: # -------------------------------------------- operators def __eq__( - left, right: Quaternion + left, right ) -> bool: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``==`` operator @@ -501,7 +501,7 @@ def __eq__( return left.binop(right, smb.qisequal, list1=False) def __ne__( - left, right: Quaternion + left, right ) -> bool: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``!=`` operator @@ -528,7 +528,7 @@ def __ne__( return left.binop(right, lambda x, y: not smb.qisequal(x, y), list1=False) def __mul__( - left, right: Quaternion + left, right ) -> Quaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``*`` operator @@ -591,7 +591,7 @@ def __mul__( raise ValueError("operands to * are of different types") def __rmul__( - right, left: Quaternion + right, left ) -> Quaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``*`` operator @@ -616,8 +616,8 @@ def __rmul__( return Quaternion([left * q._A for q in right]) def __imul__( - left, right: Quaternion - ) -> bool: # lgtm[py/not-named-self] pylint: disable=no-self-argument + left, right + ): # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``*=`` operator @@ -1570,7 +1570,7 @@ def dotb(self, omega: ArrayLike3) -> R4: return smb.qdotb(self._A, omega) def __mul__( - left, right: UnitQuaternion + left, right ) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Multiply unit quaternion @@ -1668,7 +1668,7 @@ def __mul__( raise ValueError("UnitQuaternion: operands to * are of different types") def __imul__( - left, right: UnitQuaternion + left, right ) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Multiply unit quaternion in place @@ -1692,7 +1692,7 @@ def __imul__( return left.__mul__(right) def __truediv__( - left, right: UnitQuaternion + left, right ) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``/`` operator @@ -1756,7 +1756,7 @@ def __truediv__( raise ValueError("bad operands") def __eq__( - left, right: UnitQuaternion + left, right ) -> bool: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``==`` operator @@ -1787,7 +1787,7 @@ def __eq__( ) def __ne__( - left, right: UnitQuaternion + left, right ) -> bool: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded ``!=`` operator @@ -1818,7 +1818,7 @@ def __ne__( ) def __matmul__( - left, right: UnitQuaternion + left, right ) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument """ Overloaded @ operator From bc3ab5514ec4bc14f5d9349c49717316d83b39c2 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 3 Dec 2024 11:29:23 -0500 Subject: [PATCH 18/26] Resolve override typing error for angdist by using Self --- spatialmath/pose3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index cf33230b..0bcba62b 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -853,7 +853,7 @@ def UnitQuaternion(self) -> UnitQuaternion: return UnitQuaternion(smb.r2q(self.R), check=False) - def angdist(self, other: SO3, metric: int = 6) -> Union[float, ndarray]: + def angdist(self, other: Self, metric: int = 6) -> Union[float, ndarray]: r""" Angular distance metric between rotations @@ -2009,7 +2009,7 @@ def CopyFrom( raise ValueError("Transformation matrix must not be None") return cls(np.copy(T), check=check) - def angdist(self, other: SE3, metric: int = 6) -> float: + def angdist(self, other: Self, metric: int = 6) -> float: r""" Angular distance metric between poses From 04b9fe7240c65d063e7dc6958432a203c4599f6e Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 3 Dec 2024 11:29:50 -0500 Subject: [PATCH 19/26] added TODO comment for possible resolution of override typing error of Exp --- spatialmath/pose3d.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index 0bcba62b..65beb6a2 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -803,6 +803,8 @@ def Exp( r""" Create an SO(3) rotation matrix from so(3) + TODO: clean up typing for S, possibly by introducing additional type parameters at the class level. + :param S: Lie algebra so(3) :type S: ndarray(3,3), ndarray(n,3) :param check: check that passed matrix is valid so(3), default True @@ -1805,6 +1807,8 @@ def Exp(cls, S: Union[R6, R4x4], check: bool = True) -> SE3: """ Create an SE(3) matrix from se(3) + TODO: clean up typing for S, possibly by introducing additional type parameters at the class level. + :param S: Lie algebra se(3) matrix :type S: ndarray(6), ndarray(4,4) :return: SE(3) matrix From 165614d933c777560e20b21279a148350b046c5c Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 3 Dec 2024 11:37:31 -0500 Subject: [PATCH 20/26] Consistently use @classmethod along with @overload'ed class methods. --- spatialmath/pose3d.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index 65beb6a2..40bfac59 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -465,6 +465,7 @@ def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str return cls([smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)) for _ in range(0, N)], check=False) @overload + @classmethod def Eul(cls, phi: float, theta: float, psi: float, unit: str = "rad") -> SE3: ... @@ -519,6 +520,7 @@ def Eul(cls, *angles, unit: str = "rad") -> Self: return cls([smb.eul2r(a, unit=unit) for a in angles], check=False) @overload + @classmethod def RPY(cls, roll: float, pitch: float, yaw: float, unit: str = "rad") -> SE3: ... @@ -1565,10 +1567,12 @@ def Rand( ) @overload + @classmethod def Eul(cls, phi: float, theta: float, psi: float, unit: str = "rad") -> SE3: ... @overload + @classmethod def Eul(cls, angles: ArrayLike3, unit: str = "rad") -> SE3: ... @@ -1615,10 +1619,12 @@ def Eul(cls, *angles, unit="rad") -> SE3: return cls([smb.eul2tr(a, unit=unit) for a in angles], check=False) @overload + @classmethod def RPY(cls, roll: float, pitch: float, yaw: float, unit: str = "rad") -> SE3: ... @overload + @classmethod def RPY(cls, angles: ArrayLike3, unit: str = "rad") -> SE3: ... @@ -1847,10 +1853,12 @@ def Delta(cls, d: ArrayLike6) -> SE3: return cls(smb.trnorm(smb.delta2tr(d))) @overload + @classmethod def Trans(cls, x: float, y: float, z: float) -> SE3: ... @overload + @classmethod def Trans(cls, xyz: ArrayLike3) -> SE3: ... From ecd71d1ce1d2a209ed321e6a492be3123b077c5a Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 3 Dec 2024 17:04:29 -0500 Subject: [PATCH 21/26] Resolve [assignment] mypy error for symbolic.py with additional annotation --- spatialmath/base/symbolic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spatialmath/base/symbolic.py b/spatialmath/base/symbolic.py index 34a961ad..3f03e6f5 100644 --- a/spatialmath/base/symbolic.py +++ b/spatialmath/base/symbolic.py @@ -19,7 +19,7 @@ import sympy # type: ignore _symbolics = True - symtype = (sympy.Expr,) + symtype : Tuple[type, ...] = (sympy.Expr,) from sympy import Symbol except ImportError: # pragma: no cover From 08cc749b54ad56da8317e4f8b820b3a51999b889 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Tue, 3 Dec 2024 17:16:59 -0500 Subject: [PATCH 22/26] Update type aliases using already defined aliases; eliminates a handful of mypy errors --- spatialmath/base/_types_39.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/spatialmath/base/_types_39.py b/spatialmath/base/_types_39.py index 30099099..ab5a0a63 100644 --- a/spatialmath/base/_types_39.py +++ b/spatialmath/base/_types_39.py @@ -120,36 +120,21 @@ RNx3 = NDArray # Lie group elements -SO2Array = ndarray[Tuple[L[2, 2]], dtype[floating]] # SO(2) rotation matrix -SE2Array = ndarray[Tuple[L[3, 3]], dtype[floating]] # SE(2) rigid-body transform -# SO3Array = ndarray[Tuple[L[3, 3]], dtype[floating]] -SO3Array = np.ndarray[Tuple[L[3], L[3]], dtype[floating]] # SO(3) rotation matrix -SE3Array = ndarray[Tuple[L[4], L[4]], dtype[floating]] # SE(3) rigid-body transform +SO2Array = R2x2 # SO(2) rotation matrix +SE2Array = R3x3 # SE(2) rigid-body transform +SO3Array = R3x3 # SO(3) rotation matrix +SE3Array = R4x4 # SE(3) rigid-body transform # Lie algebra elements -so2Array = ndarray[ - Tuple[L[2, 2]], dtype[floating] -] # so(2) Lie algebra of SO(2), skew-symmetrix matrix -se2Array = ndarray[ - Tuple[L[3, 3]], dtype[floating] -] # se(2) Lie algebra of SE(2), augmented skew-symmetrix matrix -so3Array = ndarray[ - Tuple[L[3, 3]], dtype[floating] -] # so(3) Lie algebra of SO(3), skew-symmetrix matrix -se3Array = ndarray[ - Tuple[L[4, 4]], dtype[floating] -] # se(3) Lie algebra of SE(3), augmented skew-symmetrix matrix +so2Array = R2x2 # so(2) Lie algebra of SO(2), skew-symmetrix matrix +se2Array = R3x3 # se(2) Lie algebra of SE(2), augmented skew-symmetrix matrix +so3Array = R3x3 # so(3) Lie algebra of SO(3), skew-symmetrix matrix +se3Array = R4x4 # se(3) Lie algebra of SE(3), augmented skew-symmetrix matrix # quaternion arrays -QuaternionArray = ndarray[ - Tuple[L[4,]], - dtype[floating], -] -UnitQuaternionArray = ndarray[ - Tuple[L[4,]], - dtype[floating], -] +QuaternionArray = R4 +UnitQuaternionArray = R4 Rn = Union[R2, R3] From 768a805af492db4af9550bea8db5a9534289832e Mon Sep 17 00:00:00 2001 From: John Barnett Date: Wed, 4 Dec 2024 14:32:09 -0500 Subject: [PATCH 23/26] Separate function skewa into skewa2 and skewa3 to resolve overload-overlap mypy error --- spatialmath/base/transformsNd.py | 34 +++++++++++++++++++------------- tests/base/test_transformsNd.py | 18 ++++++++++++++--- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/spatialmath/base/transformsNd.py b/spatialmath/base/transformsNd.py index 78d5841f..54f0377b 100644 --- a/spatialmath/base/transformsNd.py +++ b/spatialmath/base/transformsNd.py @@ -590,14 +590,26 @@ def vex(s, check=False): # ---------------------------------------------------------------------------------------# -@overload -def skewa(v: ArrayLike3) -> se2Array: - ... +def skewa2(v: ArrayLike3) -> se2Array: + v = getvector(v, None) + if len(v) == 3: + omega = np.zeros((3, 3), dtype=v.dtype) + omega[:2, :2] = skew(v[2]) + omega[:2, 2] = v[0:2] + return omega + else: + raise ValueError("expecting a 3-vector") -@overload -def skewa(v: ArrayLike6) -> se3Array: - ... +def skewa3(v: ArrayLike6) -> se3Array: + v = getvector(v, None) + if len(v) == 6: + omega = np.zeros((4, 4), dtype=v.dtype) + omega[:3, :3] = skew(v[3:6]) + omega[:3, 3] = v[0:3] + return omega + else: + raise ValueError("expecting a 6-vector") def skewa(v: Union[ArrayLike3, ArrayLike6]) -> Union[se2Array, se3Array]: @@ -633,15 +645,9 @@ def skewa(v: Union[ArrayLike3, ArrayLike6]) -> Union[se2Array, se3Array]: v = getvector(v, None) if len(v) == 3: - omega = np.zeros((3, 3), dtype=v.dtype) - omega[:2, :2] = skew(v[2]) - omega[:2, 2] = v[0:2] - return omega + return skewa2(v) elif len(v) == 6: - omega = np.zeros((4, 4), dtype=v.dtype) - omega[:3, :3] = skew(v[3:6]) - omega[:3, 3] = v[0:3] - return omega + return skewa3(v) else: raise ValueError("expecting a 3- or 6-vector") diff --git a/tests/base/test_transformsNd.py b/tests/base/test_transformsNd.py index 92d9e2a3..4568e46f 100755 --- a/tests/base/test_transformsNd.py +++ b/tests/base/test_transformsNd.py @@ -361,23 +361,35 @@ def test_isskewa(self): sk[2, 2] = 3 self.assertFalse(isskew(sk)) - def test_skewa(self): + def test_skewa3(self): # 3D - sk = skewa([1, 2, 3, 4, 5, 6]) + sk = skewa3([1, 2, 3, 4, 5, 6]) self.assertEqual(sk.shape, (4, 4)) nt.assert_almost_equal(sk.diagonal(), np.r_[0, 0, 0, 0]) nt.assert_almost_equal(sk[-1, :], np.r_[0, 0, 0, 0]) nt.assert_almost_equal(sk[:3, 3], [1, 2, 3]) nt.assert_almost_equal(vex(sk[:3, :3]), [4, 5, 6]) + def test_skewa2(self): # 2D - sk = skewa([1, 2, 3]) + sk = skewa2([1, 2, 3]) self.assertEqual(sk.shape, (3, 3)) nt.assert_almost_equal(sk.diagonal(), np.r_[0, 0, 0]) nt.assert_almost_equal(sk[-1, :], np.r_[0, 0, 0]) nt.assert_almost_equal(sk[:2, 2], [1, 2]) nt.assert_almost_equal(vex(sk[:2, :2]), [3]) + def test_skewa_skewa3(self): + # 3D + v = [1, 2, 3, 4, 5, 6] + nt.assert_equal(skewa(v), skewa3(v)) + + def test_skewa_skewa2(self): + # 2D + v = [1, 2, 3] + nt.assert_equal(skewa(v), skewa2(v)) + + def test_skew_raises(self): with self.assertRaises(ValueError): sk = skew([1, 2]) From e95ce5c8dbcd8c14fb809dd66cdb64efe075e48a Mon Sep 17 00:00:00 2001 From: John Barnett Date: Wed, 4 Dec 2024 14:41:58 -0500 Subject: [PATCH 24/26] Update library code to use skewa2 or skewa3 directly --- spatialmath/base/__init__.py | 2 ++ spatialmath/base/transforms3d.py | 4 ++-- spatialmath/twist.py | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/spatialmath/base/__init__.py b/spatialmath/base/__init__.py index 98ff87f0..85e0dbed 100644 --- a/spatialmath/base/__init__.py +++ b/spatialmath/base/__init__.py @@ -286,6 +286,8 @@ "iseye", "skew", "vex", + "skewa2", + "skewa3", "skewa", "vexa", "h2e", diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index 5437c63c..d821c3d0 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -35,7 +35,7 @@ t2r, rt2tr, skew, - skewa, + skewa3, vex, vexa, isskew, @@ -1723,7 +1723,7 @@ def delta2tr(d: R6) -> SE3Array: :SymPy: supported """ - return np.eye(4, 4) + skewa(d) + return np.eye(4, 4) + skewa3(d) def trinv(T: SE3Array) -> SE3Array: diff --git a/spatialmath/twist.py b/spatialmath/twist.py index f84a0f1b..c4cda460 100644 --- a/spatialmath/twist.py +++ b/spatialmath/twist.py @@ -916,9 +916,9 @@ def skewa(self): >>> smb.trexp(se) """ if len(self) == 1: - return smb.skewa(self.S) + return smb.skewa3(self.S) else: - return [smb.skewa(x.S) for x in self] + return [smb.skewa3(x.S) for x in self] @property def pitch(self): @@ -1558,9 +1558,9 @@ def skewa(self): >>> smb.trexp2(se) """ if len(self) == 1: - return smb.skewa(self.S) + return smb.skewa2(self.S) else: - return [smb.skewa(x.S) for x in self] + return [smb.skewa2(x.S) for x in self] def exp(self, theta=1, unit="rad"): r""" From 13a920ac48c28d79f12ed88e9491c0b44ebe9337 Mon Sep 17 00:00:00 2001 From: John Barnett Date: Wed, 4 Dec 2024 14:57:56 -0500 Subject: [PATCH 25/26] Separate trexp into trexp_SO3 and trexp_SE3 to resolve overload-overlap typing error --- spatialmath/base/transforms3d.py | 129 +++++++++++++++++-------------- 1 file changed, 73 insertions(+), 56 deletions(-) diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index d821c3d0..b1e7e180 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -1406,26 +1406,92 @@ def trlog( # ---------------------------------------------------------------------------------------# @overload # pragma: no cover -def trexp(S: so3Array, theta: Optional[float] = None, check: bool = True) -> SO3Array: +def trexp_SO3(S: so3Array, theta: Optional[float] = None, check: bool = True) -> SO3Array: ... @overload # pragma: no cover -def trexp(S: se3Array, theta: Optional[float] = None, check: bool = True) -> SE3Array: +def trexp_SO3(S: ArrayLike3, theta: Optional[float] = None, check=True) -> SO3Array: ... +def trexp_SO3(S, theta=None, check=True): + if ismatrix(S, (3, 3)) or isvector(S, 3): + # so(3) case + if ismatrix(S, (3, 3)): + # skew symmetric matrix + if check and not isskew(S): + raise ValueError("argument must be a valid so(3) element") + w = vex(S) + else: + # 3 vector + w = getvector(S) + + if theta is not None and not isunitvec(w): + raise ValueError("If theta is specified S must be a unit twist") + + # do Rodrigues' formula for rotation + return rodrigues(w, theta) + else: + raise ValueError(" First argument must be SO(3) or 3-vector") + + @overload # pragma: no cover -def trexp(S: ArrayLike3, theta: Optional[float] = None, check=True) -> SO3Array: +def trexp_SE3(S: se3Array, theta: Optional[float] = None, check: bool = True) -> SE3Array: ... @overload # pragma: no cover -def trexp(S: ArrayLike6, theta: Optional[float] = None, check=True) -> SE3Array: +def trexp_SE3(S: ArrayLike6, theta: Optional[float] = None, check=True) -> SE3Array: ... -def trexp(S, theta=None, check=True): +def trexp_SE3(S, theta=None, check=True): + if ismatrix(S, (4, 4)) or isvector(S, 6): + # se(3) case + if ismatrix(S, (4, 4)): + # augmentented skew matrix + if check and not isskewa(S): + raise ValueError("argument must be a valid se(3) element") + tw = vexa(cast(se3Array, S)) + else: + # 6 vector + tw = getvector(S) + + if iszerovec(tw): + return np.eye(4) + + if theta is None: + (tw, theta) = unittwist_norm(tw) + else: + if theta == 0: + return np.eye(4) + elif not isunittwist(tw): + raise ValueError("If theta is specified S must be a unit twist") + + # tw is a unit twist, th is its magnitude + t = tw[0:3] + w = tw[3:6] + + R = rodrigues(w, theta) + + skw = skew(w) + V = ( + np.eye(3) * theta + + (1.0 - math.cos(theta)) * skw + + (theta - math.sin(theta)) * skw @ skw + ) + + return rt2tr(R, V @ t) + else: + raise ValueError(" First argument must be SE(3) or 6-vector") + + +def trexp( + S: so3Array | ArrayLike3 | se3Array | ArrayLike6, + theta: Optional[float] = None, + check: bool = True, +): """ Exponential of se(3) or so(3) matrix @@ -1486,58 +1552,9 @@ def trexp(S, theta=None, check=True): """ if ismatrix(S, (4, 4)) or isvector(S, 6): - # se(3) case - if ismatrix(S, (4, 4)): - # augmentented skew matrix - if check and not isskewa(S): - raise ValueError("argument must be a valid se(3) element") - tw = vexa(cast(se3Array, S)) - else: - # 6 vector - tw = getvector(S) - - if iszerovec(tw): - return np.eye(4) - - if theta is None: - (tw, theta) = unittwist_norm(tw) - else: - if theta == 0: - return np.eye(4) - elif not isunittwist(tw): - raise ValueError("If theta is specified S must be a unit twist") - - # tw is a unit twist, th is its magnitude - t = tw[0:3] - w = tw[3:6] - - R = rodrigues(w, theta) - - skw = skew(w) - V = ( - np.eye(3) * theta - + (1.0 - math.cos(theta)) * skw - + (theta - math.sin(theta)) * skw @ skw - ) - - return rt2tr(R, V @ t) - + return trexp_SE3(S, theta=theta, check=check) elif ismatrix(S, (3, 3)) or isvector(S, 3): - # so(3) case - if ismatrix(S, (3, 3)): - # skew symmetric matrix - if check and not isskew(S): - raise ValueError("argument must be a valid so(3) element") - w = vex(S) - else: - # 3 vector - w = getvector(S) - - if theta is not None and not isunitvec(w): - raise ValueError("If theta is specified S must be a unit twist") - - # do Rodrigues' formula for rotation - return rodrigues(w, theta) + return trexp_SO3(S, theta=theta, check=check) else: raise ValueError(" First argument must be SO(3), 3-vector, SE(3) or 6-vector") From 1769e81813e97e64abc671fe16a8cd0aef98ca2a Mon Sep 17 00:00:00 2001 From: John Barnett Date: Wed, 4 Dec 2024 15:06:51 -0500 Subject: [PATCH 26/26] Use trexp_SO3 and trexp_SE3 directly in library code for better typing --- spatialmath/base/transforms3d.py | 2 +- spatialmath/pose3d.py | 8 ++++---- spatialmath/twist.py | 14 +++++++------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index b1e7e180..a5ed52bd 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -2139,7 +2139,7 @@ def x2r(r: ArrayLike3, representation: str = "rpy/xyz") -> SO3Array: elif representation in ("arm", "vehicle", "camera"): R = rpy2r(r, order=representation) elif representation == "exp": - R = trexp(r) + R = trexp_SO3(r) else: raise ValueError(f"unknown representation: {representation}") return R diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index 40bfac59..beb79526 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -830,9 +830,9 @@ def Exp( :seealso: :func:`spatialmath.base.transforms3d.trexp`, :func:`spatialmath.base.transformsNd.skew` """ if smb.ismatrix(S, (-1, 3)) and not so3: - return cls([smb.trexp(s, check=check) for s in S], check=False) + return cls([smb.trexp_SO3(s, check=check) for s in S], check=False) else: - return cls(smb.trexp(cast(R3, S), check=check), check=False) + return cls(smb.trexp_SO3(cast(R3, S), check=check), check=False) def UnitQuaternion(self) -> UnitQuaternion: """ @@ -1828,9 +1828,9 @@ def Exp(cls, S: Union[R6, R4x4], check: bool = True) -> SE3: :seealso: :func:`~spatialmath.base.transforms3d.trexp`, :func:`~spatialmath.base.transformsNd.skew` """ if smb.isvector(S, 6): - return cls(smb.trexp(smb.getvector(S)), check=False) + return cls(smb.trexp_SE3(smb.getvector(S)), check=False) else: - return cls(smb.trexp(S), check=False) + return cls(smb.trexp_SE3(S), check=False) @classmethod def Delta(cls, d: ArrayLike6) -> SE3: diff --git a/spatialmath/twist.py b/spatialmath/twist.py index c4cda460..1a30a8e6 100644 --- a/spatialmath/twist.py +++ b/spatialmath/twist.py @@ -1015,13 +1015,13 @@ def SE3(self, theta=1, unit="rad"): if len(theta) == 1: # theta is a scalar - return SE3(smb.trexp(self.S * theta)) + return SE3(smb.trexp_SE3(self.S * theta)) else: # theta is a vector if len(self) == 1: - return SE3([smb.trexp(self.S * t) for t in theta]) + return SE3([smb.trexp_SE3(self.S * t) for t in theta]) elif len(self) == len(theta): - return SE3([smb.trexp(S * t) for S, t in zip(self.data, theta)]) + return SE3([smb.trexp_SE3(S * t) for S, t in zip(self.data, theta)]) else: raise ValueError("length of twist and theta not consistent") @@ -1072,9 +1072,9 @@ def exp(self, theta=1, unit="rad"): theta = smb.getunit(theta, unit) if len(self) == 1: - return SE3([smb.trexp(self.S * t) for t in theta], check=False) + return SE3([smb.trexp_SE3(self.S * t) for t in theta], check=False) elif len(self) == len(theta): - return SE3([smb.trexp(s * t) for s, t in zip(self.S, theta)], check=False) + return SE3([smb.trexp_SE3(s * t) for s, t in zip(self.S, theta)], check=False) else: raise ValueError("length mismatch") @@ -1136,12 +1136,12 @@ def __mul__( return Twist3( left.binop( right, - lambda x, y: smb.trlog(smb.trexp(x) @ smb.trexp(y), twist=True), + lambda x, y: smb.trlog(smb.trexp_SE3(x) @ smb.trexp_SE3(y), twist=True), ) ) elif isinstance(right, SE3): # twist * SE3 -> SE3 - return SE3(left.binop(right, lambda x, y: smb.trexp(x) @ y), check=False) + return SE3(left.binop(right, lambda x, y: smb.trexp_SE3(x) @ y), check=False) elif smb.isscalar(right): # return Twist(left.S * right) return Twist3(left.binop(right, lambda x, y: x * y))