From fb9147736d937cd97c16a4b0ffe1755b91d3474e Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Fri, 25 Jul 2025 19:02:56 +0900 Subject: [PATCH 1/4] stdlib(sqlite3): Raise ProgrammingError for missing named parameter (#6036) --- Lib/test/test_sqlite3/test_dbapi.py | 2 -- stdlib/src/sqlite.rs | 10 +++++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 6c95a887a3..12cd997418 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -908,8 +908,6 @@ def __missing__(self, key): row = self.cu.fetchone() self.assertEqual(row[0], "foo") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_execute_dict_mapping_too_little_args(self): self.cu.execute("insert into test(name) values ('foo')") with self.assertRaises(sqlite.ProgrammingError): diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index 973fc431f8..8db90a7bba 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -2725,7 +2725,15 @@ mod _sqlite { let name = unsafe { name.add(1) }; let name = ptr_to_str(name, vm)?; - let val = dict.get_item(name, vm)?; + let val = match dict.get_item_opt(name, vm)? { + Some(val) => val, + None => { + return Err(new_programming_error( + vm, + format!("You did not supply a value for binding parameter :{name}.",), + )); + } + }; self.bind_parameter(i, &val, vm)?; } From ae03bacb3912b19daf21d08173f825becc737659 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:04:49 +0200 Subject: [PATCH 2/4] Update `{_py,}decimal.py` from 3.13.5 (#6034) --- Lib/_pydecimal.py | 261 +++++------------- Lib/decimal.py | 108 +++++++- Lib/test/decimaltestdata/abs.decTest | 2 +- Lib/test/decimaltestdata/ddFMA.decTest | 2 +- Lib/test/decimaltestdata/ddQuantize.decTest | 2 +- Lib/test/decimaltestdata/ddRemainder.decTest | 2 +- .../decimaltestdata/ddRemainderNear.decTest | 2 +- Lib/test/decimaltestdata/dqRemainder.decTest | 2 +- .../decimaltestdata/dqRemainderNear.decTest | 2 +- Lib/test/decimaltestdata/exp.decTest | 2 +- Lib/test/decimaltestdata/extra.decTest | 2 +- Lib/test/decimaltestdata/remainder.decTest | 2 +- .../decimaltestdata/remainderNear.decTest | 2 +- Lib/test/test_decimal.py | 149 ++++++++-- 14 files changed, 315 insertions(+), 225 deletions(-) diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py index a0b608380a..ff80180a79 100644 --- a/Lib/_pydecimal.py +++ b/Lib/_pydecimal.py @@ -13,104 +13,7 @@ # bug) and will be backported. At this point the spec is stabilizing # and the updates are becoming fewer, smaller, and less significant. -""" -This is an implementation of decimal floating point arithmetic based on -the General Decimal Arithmetic Specification: - - http://speleotrove.com/decimal/decarith.html - -and IEEE standard 854-1987: - - http://en.wikipedia.org/wiki/IEEE_854-1987 - -Decimal floating point has finite precision with arbitrarily large bounds. - -The purpose of this module is to support arithmetic using familiar -"schoolhouse" rules and to avoid some of the tricky representation -issues associated with binary floating point. The package is especially -useful for financial applications or for contexts where users have -expectations that are at odds with binary floating point (for instance, -in binary floating point, 1.00 % 0.1 gives 0.09999999999999995 instead -of 0.0; Decimal('1.00') % Decimal('0.1') returns the expected -Decimal('0.00')). - -Here are some examples of using the decimal module: - ->>> from decimal import * ->>> setcontext(ExtendedContext) ->>> Decimal(0) -Decimal('0') ->>> Decimal('1') -Decimal('1') ->>> Decimal('-.0123') -Decimal('-0.0123') ->>> Decimal(123456) -Decimal('123456') ->>> Decimal('123.45e12345678') -Decimal('1.2345E+12345680') ->>> Decimal('1.33') + Decimal('1.27') -Decimal('2.60') ->>> Decimal('12.34') + Decimal('3.87') - Decimal('18.41') -Decimal('-2.20') ->>> dig = Decimal(1) ->>> print(dig / Decimal(3)) -0.333333333 ->>> getcontext().prec = 18 ->>> print(dig / Decimal(3)) -0.333333333333333333 ->>> print(dig.sqrt()) -1 ->>> print(Decimal(3).sqrt()) -1.73205080756887729 ->>> print(Decimal(3) ** 123) -4.85192780976896427E+58 ->>> inf = Decimal(1) / Decimal(0) ->>> print(inf) -Infinity ->>> neginf = Decimal(-1) / Decimal(0) ->>> print(neginf) --Infinity ->>> print(neginf + inf) -NaN ->>> print(neginf * inf) --Infinity ->>> print(dig / 0) -Infinity ->>> getcontext().traps[DivisionByZero] = 1 ->>> print(dig / 0) -Traceback (most recent call last): - ... - ... - ... -decimal.DivisionByZero: x / 0 ->>> c = Context() ->>> c.traps[InvalidOperation] = 0 ->>> print(c.flags[InvalidOperation]) -0 ->>> c.divide(Decimal(0), Decimal(0)) -Decimal('NaN') ->>> c.traps[InvalidOperation] = 1 ->>> print(c.flags[InvalidOperation]) -1 ->>> c.flags[InvalidOperation] = 0 ->>> print(c.flags[InvalidOperation]) -0 ->>> print(c.divide(Decimal(0), Decimal(0))) -Traceback (most recent call last): - ... - ... - ... -decimal.InvalidOperation: 0 / 0 ->>> print(c.flags[InvalidOperation]) -1 ->>> c.flags[InvalidOperation] = 0 ->>> c.traps[InvalidOperation] = 0 ->>> print(c.divide(Decimal(0), Decimal(0))) -NaN ->>> print(c.flags[InvalidOperation]) -1 ->>> -""" +"""Python decimal arithmetic module""" __all__ = [ # Two major classes @@ -140,8 +43,11 @@ # Limits for the C version for compatibility 'MAX_PREC', 'MAX_EMAX', 'MIN_EMIN', 'MIN_ETINY', - # C version: compile time choice that enables the thread local context - 'HAVE_THREADS' + # C version: compile time choice that enables the thread local context (deprecated, now always true) + 'HAVE_THREADS', + + # C version: compile time choice that enables the coroutine local context + 'HAVE_CONTEXTVAR' ] __xname__ = __name__ # sys.modules lookup (--without-threads) @@ -156,7 +62,7 @@ try: from collections import namedtuple as _namedtuple - DecimalTuple = _namedtuple('DecimalTuple', 'sign digits exponent') + DecimalTuple = _namedtuple('DecimalTuple', 'sign digits exponent', module='decimal') except ImportError: DecimalTuple = lambda *args: args @@ -172,6 +78,7 @@ # Compatibility with the C version HAVE_THREADS = True +HAVE_CONTEXTVAR = True if sys.maxsize == 2**63-1: MAX_PREC = 999999999999999999 MAX_EMAX = 999999999999999999 @@ -190,7 +97,7 @@ class DecimalException(ArithmeticError): Used exceptions derive from this. If an exception derives from another exception besides this (such as - Underflow (Inexact, Rounded, Subnormal) that indicates that it is only + Underflow (Inexact, Rounded, Subnormal)) that indicates that it is only called if the others are present. This isn't actually used for anything, though. @@ -238,7 +145,7 @@ class InvalidOperation(DecimalException): x ** (+-)INF An operand is invalid - The result of the operation after these is a quiet positive NaN, + The result of the operation after this is a quiet positive NaN, except when the cause is a signaling NaN, in which case the result is also a quiet NaN, but with the original sign, and an optional diagnostic information. @@ -431,82 +338,40 @@ class FloatOperation(DecimalException, TypeError): ##### Context Functions ################################################## # The getcontext() and setcontext() function manage access to a thread-local -# current context. Py2.4 offers direct support for thread locals. If that -# is not available, use threading.current_thread() which is slower but will -# work for older Pythons. If threads are not part of the build, create a -# mock threading object with threading.local() returning the module namespace. - -try: - import threading -except ImportError: - # Python was compiled without threads; create a mock object instead - class MockThreading(object): - def local(self, sys=sys): - return sys.modules[__xname__] - threading = MockThreading() - del MockThreading - -try: - threading.local - -except AttributeError: - - # To fix reloading, force it to create a new context - # Old contexts have different exceptions in their dicts, making problems. - if hasattr(threading.current_thread(), '__decimal_context__'): - del threading.current_thread().__decimal_context__ +# current context. - def setcontext(context): - """Set this thread's context to context.""" - if context in (DefaultContext, BasicContext, ExtendedContext): - context = context.copy() - context.clear_flags() - threading.current_thread().__decimal_context__ = context +import contextvars - def getcontext(): - """Returns this thread's context. +_current_context_var = contextvars.ContextVar('decimal_context') - If this thread does not yet have a context, returns - a new context and sets this thread's context. - New contexts are copies of DefaultContext. - """ - try: - return threading.current_thread().__decimal_context__ - except AttributeError: - context = Context() - threading.current_thread().__decimal_context__ = context - return context +_context_attributes = frozenset( + ['prec', 'Emin', 'Emax', 'capitals', 'clamp', 'rounding', 'flags', 'traps'] +) -else: +def getcontext(): + """Returns this thread's context. - local = threading.local() - if hasattr(local, '__decimal_context__'): - del local.__decimal_context__ + If this thread does not yet have a context, returns + a new context and sets this thread's context. + New contexts are copies of DefaultContext. + """ + try: + return _current_context_var.get() + except LookupError: + context = Context() + _current_context_var.set(context) + return context + +def setcontext(context): + """Set this thread's context to context.""" + if context in (DefaultContext, BasicContext, ExtendedContext): + context = context.copy() + context.clear_flags() + _current_context_var.set(context) - def getcontext(_local=local): - """Returns this thread's context. +del contextvars # Don't contaminate the namespace - If this thread does not yet have a context, returns - a new context and sets this thread's context. - New contexts are copies of DefaultContext. - """ - try: - return _local.__decimal_context__ - except AttributeError: - context = Context() - _local.__decimal_context__ = context - return context - - def setcontext(context, _local=local): - """Set this thread's context to context.""" - if context in (DefaultContext, BasicContext, ExtendedContext): - context = context.copy() - context.clear_flags() - _local.__decimal_context__ = context - - del threading, local # Don't contaminate the namespace - -def localcontext(ctx=None): +def localcontext(ctx=None, **kwargs): """Return a context manager for a copy of the supplied context Uses a copy of the current context if no context is specified @@ -542,8 +407,14 @@ def sin(x): >>> print(getcontext().prec) 28 """ - if ctx is None: ctx = getcontext() - return _ContextManager(ctx) + if ctx is None: + ctx = getcontext() + ctx_manager = _ContextManager(ctx) + for key, value in kwargs.items(): + if key not in _context_attributes: + raise TypeError(f"'{key}' is an invalid keyword argument for this function") + setattr(ctx_manager.new_context, key, value) + return ctx_manager ##### Decimal class ####################################################### @@ -553,7 +424,7 @@ def sin(x): # numbers.py for more detail. class Decimal(object): - """Floating point class for decimal arithmetic.""" + """Floating-point class for decimal arithmetic.""" __slots__ = ('_exp','_int','_sign', '_is_special') # Generally, the value of the Decimal instance is given by @@ -993,7 +864,7 @@ def __hash__(self): if self.is_snan(): raise TypeError('Cannot hash a signaling NaN value.') elif self.is_nan(): - return _PyHASH_NAN + return object.__hash__(self) else: if self._sign: return -_PyHASH_INF @@ -1674,13 +1545,13 @@ def __int__(self): __trunc__ = __int__ + @property def real(self): return self - real = property(real) + @property def imag(self): return Decimal(0) - imag = property(imag) def conjugate(self): return self @@ -2260,10 +2131,16 @@ def _power_exact(self, other, p): else: return None - if xc >= 10**p: + # An exact power of 10 is representable, but can convert to a + # string of any length. But an exact power of 10 shouldn't be + # possible at this point. + assert xc > 1, self + assert xc % 10 != 0, self + strxc = str(xc) + if len(strxc) > p: return None xe = -e-xe - return _dec_from_triple(0, str(xc), xe) + return _dec_from_triple(0, strxc, xe) # now y is positive; find m and n such that y = m/n if ye >= 0: @@ -2272,7 +2149,7 @@ def _power_exact(self, other, p): if xe != 0 and len(str(abs(yc*xe))) <= -ye: return None xc_bits = _nbits(xc) - if xc != 1 and len(str(abs(yc)*xc_bits)) <= -ye: + if len(str(abs(yc)*xc_bits)) <= -ye: return None m, n = yc, 10**(-ye) while m % 2 == n % 2 == 0: @@ -2285,7 +2162,7 @@ def _power_exact(self, other, p): # compute nth root of xc*10**xe if n > 1: # if 1 < xc < 2**n then xc isn't an nth power - if xc != 1 and xc_bits <= n: + if xc_bits <= n: return None xe, rem = divmod(xe, n) @@ -2313,13 +2190,18 @@ def _power_exact(self, other, p): return None xc = xc**m xe *= m - if xc > 10**p: + # An exact power of 10 is representable, but can convert to a string + # of any length. But an exact power of 10 shouldn't be possible at + # this point. + assert xc > 1, self + assert xc % 10 != 0, self + str_xc = str(xc) + if len(str_xc) > p: return None # by this point the result *is* exactly representable # adjust the exponent to get as close as possible to the ideal # exponent, if necessary - str_xc = str(xc) if other._isinteger() and other._sign == 0: ideal_exponent = self._exp*int(other) zeros = min(xe-ideal_exponent, p-len(str_xc)) @@ -3837,6 +3719,10 @@ def __format__(self, specifier, context=None, _localeconv=None): # represented in fixed point; rescale them to 0e0. if not self and self._exp > 0 and spec['type'] in 'fF%': self = self._rescale(0, rounding) + if not self and spec['no_neg_0'] and self._sign: + adjusted_sign = 0 + else: + adjusted_sign = self._sign # figure out placement of the decimal point leftdigits = self._exp + len(self._int) @@ -3867,7 +3753,7 @@ def __format__(self, specifier, context=None, _localeconv=None): # done with the decimal-specific stuff; hand over the rest # of the formatting to the _format_number function - return _format_number(self._sign, intpart, fracpart, exp, spec) + return _format_number(adjusted_sign, intpart, fracpart, exp, spec) def _dec_from_triple(sign, coefficient, exponent, special=False): """Create a decimal instance directly, without any validation, @@ -5677,8 +5563,6 @@ def __init__(self, value=None): def __repr__(self): return "(%r, %r, %r)" % (self.sign, self.int, self.exp) - __str__ = __repr__ - def _normalize(op1, op2, prec = 0): @@ -6187,7 +6071,7 @@ def _convert_for_comparison(self, other, equality_op=False): # # A format specifier for Decimal looks like: # -# [[fill]align][sign][#][0][minimumwidth][,][.precision][type] +# [[fill]align][sign][z][#][0][minimumwidth][,][.precision][type] _parse_format_specifier_regex = re.compile(r"""\A (?: @@ -6195,6 +6079,7 @@ def _convert_for_comparison(self, other, equality_op=False): (?P[<>=^]) )? (?P[-+ ])? +(?Pz)? (?P\#)? (?P0)? (?P(?!0)\d+)? diff --git a/Lib/decimal.py b/Lib/decimal.py index 7746ea2601..ee3147f5dd 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -1,11 +1,109 @@ +"""Decimal fixed-point and floating-point arithmetic. + +This is an implementation of decimal floating-point arithmetic based on +the General Decimal Arithmetic Specification: + + http://speleotrove.com/decimal/decarith.html + +and IEEE standard 854-1987: + + http://en.wikipedia.org/wiki/IEEE_854-1987 + +Decimal floating point has finite precision with arbitrarily large bounds. + +The purpose of this module is to support arithmetic using familiar +"schoolhouse" rules and to avoid some of the tricky representation +issues associated with binary floating point. The package is especially +useful for financial applications or for contexts where users have +expectations that are at odds with binary floating point (for instance, +in binary floating point, 1.00 % 0.1 gives 0.09999999999999995 instead +of 0.0; Decimal('1.00') % Decimal('0.1') returns the expected +Decimal('0.00')). + +Here are some examples of using the decimal module: + +>>> from decimal import * +>>> setcontext(ExtendedContext) +>>> Decimal(0) +Decimal('0') +>>> Decimal('1') +Decimal('1') +>>> Decimal('-.0123') +Decimal('-0.0123') +>>> Decimal(123456) +Decimal('123456') +>>> Decimal('123.45e12345678') +Decimal('1.2345E+12345680') +>>> Decimal('1.33') + Decimal('1.27') +Decimal('2.60') +>>> Decimal('12.34') + Decimal('3.87') - Decimal('18.41') +Decimal('-2.20') +>>> dig = Decimal(1) +>>> print(dig / Decimal(3)) +0.333333333 +>>> getcontext().prec = 18 +>>> print(dig / Decimal(3)) +0.333333333333333333 +>>> print(dig.sqrt()) +1 +>>> print(Decimal(3).sqrt()) +1.73205080756887729 +>>> print(Decimal(3) ** 123) +4.85192780976896427E+58 +>>> inf = Decimal(1) / Decimal(0) +>>> print(inf) +Infinity +>>> neginf = Decimal(-1) / Decimal(0) +>>> print(neginf) +-Infinity +>>> print(neginf + inf) +NaN +>>> print(neginf * inf) +-Infinity +>>> print(dig / 0) +Infinity +>>> getcontext().traps[DivisionByZero] = 1 +>>> print(dig / 0) +Traceback (most recent call last): + ... + ... + ... +decimal.DivisionByZero: x / 0 +>>> c = Context() +>>> c.traps[InvalidOperation] = 0 +>>> print(c.flags[InvalidOperation]) +0 +>>> c.divide(Decimal(0), Decimal(0)) +Decimal('NaN') +>>> c.traps[InvalidOperation] = 1 +>>> print(c.flags[InvalidOperation]) +1 +>>> c.flags[InvalidOperation] = 0 +>>> print(c.flags[InvalidOperation]) +0 +>>> print(c.divide(Decimal(0), Decimal(0))) +Traceback (most recent call last): + ... + ... + ... +decimal.InvalidOperation: 0 / 0 +>>> print(c.flags[InvalidOperation]) +1 +>>> c.flags[InvalidOperation] = 0 +>>> c.traps[InvalidOperation] = 0 +>>> print(c.divide(Decimal(0), Decimal(0))) +NaN +>>> print(c.flags[InvalidOperation]) +1 +>>> +""" try: from _decimal import * - from _decimal import __doc__ from _decimal import __version__ from _decimal import __libmpdec_version__ except ImportError: - from _pydecimal import * - from _pydecimal import __doc__ - from _pydecimal import __version__ - from _pydecimal import __libmpdec_version__ + import _pydecimal + import sys + _pydecimal.__doc__ = __doc__ + sys.modules[__name__] = _pydecimal diff --git a/Lib/test/decimaltestdata/abs.decTest b/Lib/test/decimaltestdata/abs.decTest index 01f73d7766..569b8fcd84 100644 --- a/Lib/test/decimaltestdata/abs.decTest +++ b/Lib/test/decimaltestdata/abs.decTest @@ -20,7 +20,7 @@ version: 2.59 -- This set of tests primarily tests the existence of the operator. --- Additon, subtraction, rounding, and more overflows are tested +-- Addition, subtraction, rounding, and more overflows are tested -- elsewhere. precision: 9 diff --git a/Lib/test/decimaltestdata/ddFMA.decTest b/Lib/test/decimaltestdata/ddFMA.decTest index 9094fc015b..7f2e523037 100644 --- a/Lib/test/decimaltestdata/ddFMA.decTest +++ b/Lib/test/decimaltestdata/ddFMA.decTest @@ -1663,7 +1663,7 @@ ddfma375087 fma 1 12345678 1E-33 -> 12345678.00000001 Inexac ddfma375088 fma 1 12345678 1E-34 -> 12345678.00000001 Inexact Rounded ddfma375089 fma 1 12345678 1E-35 -> 12345678.00000001 Inexact Rounded --- desctructive subtraction (from remainder tests) +-- destructive subtraction (from remainder tests) -- +++ some of these will be off-by-one remainder vs remainderNear diff --git a/Lib/test/decimaltestdata/ddQuantize.decTest b/Lib/test/decimaltestdata/ddQuantize.decTest index 9177620169..e1c5674d9a 100644 --- a/Lib/test/decimaltestdata/ddQuantize.decTest +++ b/Lib/test/decimaltestdata/ddQuantize.decTest @@ -462,7 +462,7 @@ ddqua520 quantize 1.234 1e359 -> 0E+359 Inexact Rounded ddqua521 quantize 123.456 1e359 -> 0E+359 Inexact Rounded ddqua522 quantize 1.234 1e359 -> 0E+359 Inexact Rounded ddqua523 quantize 123.456 1e359 -> 0E+359 Inexact Rounded --- next four are "won't fit" overfl +-- next four are "won't fit" overflow ddqua526 quantize 1.234 1e-299 -> NaN Invalid_operation ddqua527 quantize 123.456 1e-299 -> NaN Invalid_operation ddqua528 quantize 1.234 1e-299 -> NaN Invalid_operation diff --git a/Lib/test/decimaltestdata/ddRemainder.decTest b/Lib/test/decimaltestdata/ddRemainder.decTest index 5bd1e32d01..b1866d39a2 100644 --- a/Lib/test/decimaltestdata/ddRemainder.decTest +++ b/Lib/test/decimaltestdata/ddRemainder.decTest @@ -422,7 +422,7 @@ ddrem757 remainder 1 sNaN -> NaN Invalid_operation ddrem758 remainder 1000 sNaN -> NaN Invalid_operation ddrem759 remainder Inf -sNaN -> -NaN Invalid_operation --- propaging NaNs +-- propagating NaNs ddrem760 remainder NaN1 NaN7 -> NaN1 ddrem761 remainder sNaN2 NaN8 -> NaN2 Invalid_operation ddrem762 remainder NaN3 sNaN9 -> NaN9 Invalid_operation diff --git a/Lib/test/decimaltestdata/ddRemainderNear.decTest b/Lib/test/decimaltestdata/ddRemainderNear.decTest index 6ba64ebafe..bbe82ea374 100644 --- a/Lib/test/decimaltestdata/ddRemainderNear.decTest +++ b/Lib/test/decimaltestdata/ddRemainderNear.decTest @@ -450,7 +450,7 @@ ddrmn757 remaindernear 1 sNaN -> NaN Invalid_operation ddrmn758 remaindernear 1000 sNaN -> NaN Invalid_operation ddrmn759 remaindernear Inf -sNaN -> -NaN Invalid_operation --- propaging NaNs +-- propagating NaNs ddrmn760 remaindernear NaN1 NaN7 -> NaN1 ddrmn761 remaindernear sNaN2 NaN8 -> NaN2 Invalid_operation ddrmn762 remaindernear NaN3 sNaN9 -> NaN9 Invalid_operation diff --git a/Lib/test/decimaltestdata/dqRemainder.decTest b/Lib/test/decimaltestdata/dqRemainder.decTest index bae8eae526..e0aaca3747 100644 --- a/Lib/test/decimaltestdata/dqRemainder.decTest +++ b/Lib/test/decimaltestdata/dqRemainder.decTest @@ -418,7 +418,7 @@ dqrem757 remainder 1 sNaN -> NaN Invalid_operation dqrem758 remainder 1000 sNaN -> NaN Invalid_operation dqrem759 remainder Inf -sNaN -> -NaN Invalid_operation --- propaging NaNs +-- propagating NaNs dqrem760 remainder NaN1 NaN7 -> NaN1 dqrem761 remainder sNaN2 NaN8 -> NaN2 Invalid_operation dqrem762 remainder NaN3 sNaN9 -> NaN9 Invalid_operation diff --git a/Lib/test/decimaltestdata/dqRemainderNear.decTest b/Lib/test/decimaltestdata/dqRemainderNear.decTest index b850626fe4..2c5c3f5074 100644 --- a/Lib/test/decimaltestdata/dqRemainderNear.decTest +++ b/Lib/test/decimaltestdata/dqRemainderNear.decTest @@ -450,7 +450,7 @@ dqrmn757 remaindernear 1 sNaN -> NaN Invalid_operation dqrmn758 remaindernear 1000 sNaN -> NaN Invalid_operation dqrmn759 remaindernear Inf -sNaN -> -NaN Invalid_operation --- propaging NaNs +-- propagating NaNs dqrmn760 remaindernear NaN1 NaN7 -> NaN1 dqrmn761 remaindernear sNaN2 NaN8 -> NaN2 Invalid_operation dqrmn762 remaindernear NaN3 sNaN9 -> NaN9 Invalid_operation diff --git a/Lib/test/decimaltestdata/exp.decTest b/Lib/test/decimaltestdata/exp.decTest index 6a7af23b62..e01d7a8f92 100644 --- a/Lib/test/decimaltestdata/exp.decTest +++ b/Lib/test/decimaltestdata/exp.decTest @@ -28,7 +28,7 @@ rounding: half_even maxExponent: 384 minexponent: -383 --- basics (examples in specificiation, etc.) +-- basics (examples in specification, etc.) expx001 exp -Infinity -> 0 expx002 exp -10 -> 0.0000453999298 Inexact Rounded expx003 exp -1 -> 0.367879441 Inexact Rounded diff --git a/Lib/test/decimaltestdata/extra.decTest b/Lib/test/decimaltestdata/extra.decTest index b630d8e3f9..31291202a3 100644 --- a/Lib/test/decimaltestdata/extra.decTest +++ b/Lib/test/decimaltestdata/extra.decTest @@ -156,7 +156,7 @@ extr1302 fma -Inf 0E-456 sNaN148 -> NaN Invalid_operation -- max/min/max_mag/min_mag bug in 2.5.2/2.6/3.0: max(NaN, finite) gave -- incorrect answers when the finite number required rounding; similarly --- for the other thre functions +-- for the other three functions maxexponent: 999 minexponent: -999 precision: 6 diff --git a/Lib/test/decimaltestdata/remainder.decTest b/Lib/test/decimaltestdata/remainder.decTest index 7a1061b1e6..4f59b33287 100644 --- a/Lib/test/decimaltestdata/remainder.decTest +++ b/Lib/test/decimaltestdata/remainder.decTest @@ -435,7 +435,7 @@ remx757 remainder 1 sNaN -> NaN Invalid_operation remx758 remainder 1000 sNaN -> NaN Invalid_operation remx759 remainder Inf -sNaN -> -NaN Invalid_operation --- propaging NaNs +-- propagating NaNs remx760 remainder NaN1 NaN7 -> NaN1 remx761 remainder sNaN2 NaN8 -> NaN2 Invalid_operation remx762 remainder NaN3 sNaN9 -> NaN9 Invalid_operation diff --git a/Lib/test/decimaltestdata/remainderNear.decTest b/Lib/test/decimaltestdata/remainderNear.decTest index b768b9e0cf..000b1424d8 100644 --- a/Lib/test/decimaltestdata/remainderNear.decTest +++ b/Lib/test/decimaltestdata/remainderNear.decTest @@ -498,7 +498,7 @@ rmnx758 remaindernear 1000 sNaN -> NaN Invalid_operation rmnx759 remaindernear Inf sNaN -> NaN Invalid_operation rmnx760 remaindernear NaN sNaN -> NaN Invalid_operation --- propaging NaNs +-- propagating NaNs rmnx761 remaindernear NaN1 NaN7 -> NaN1 rmnx762 remaindernear sNaN2 NaN8 -> NaN2 Invalid_operation rmnx763 remaindernear NaN3 -sNaN9 -> -NaN9 Invalid_operation diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 0493d6a41d..163ca92bb4 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -24,6 +24,7 @@ with the corresponding argument. """ +import logging import math import os, sys import operator @@ -812,8 +813,6 @@ def test_explicit_context_create_from_float(self): x = random.expovariate(0.01) * (random.random() * 2.0 - 1.0) self.assertEqual(x, float(nc.create_decimal(x))) # roundtrip - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_unicode_digits(self): Decimal = self.decimal.Decimal @@ -831,6 +830,11 @@ class CExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase): class PyExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase): decimal = P + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_unicode_digits(self): # TODO(RUSTPYTHON): Remove this test when it pass + return super().test_unicode_digits() + class ImplicitConstructionTest: '''Unit tests for Implicit Construction cases of Decimal.''' @@ -916,8 +920,6 @@ class PyImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase): class FormatTest: '''Unit tests for the format function.''' - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_formatting(self): Decimal = self.decimal.Decimal @@ -1113,6 +1115,13 @@ def test_formatting(self): ('z>z6.1f', '-0.', 'zzz0.0'), ('x>z6.1f', '-0.', 'xxx0.0'), ('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char + ('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char + + # issue 114563 ('z' format on F type in cdecimal) + ('z3,.10F', '-6.24E-323', '0.0000000000'), + + # issue 91060 ('#' format in cdecimal) + ('#', '0', '0.'), # issue 6850 ('a=-7.0', '0.12345', 'aaaa0.1'), @@ -1128,8 +1137,6 @@ def test_formatting(self): # bytes format argument self.assertRaises(TypeError, Decimal(1).__format__, b'-020') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_negative_zero_format_directed_rounding(self): with self.decimal.localcontext() as ctx: ctx.rounding = ROUND_CEILING @@ -1228,7 +1235,31 @@ def get_fmt(x, override=None, fmt='n'): self.assertEqual(get_fmt(Decimal('-1.5'), dotsep_wide, '020n'), '-0\u00b4000\u00b4000\u00b4000\u00b4001\u00bf5') - @run_with_locale('LC_ALL', 'ps_AF') + def test_deprecated_N_format(self): + Decimal = self.decimal.Decimal + h = Decimal('6.62607015e-34') + if self.decimal == C: + with self.assertWarns(DeprecationWarning) as cm: + r = format(h, 'N') + self.assertEqual(cm.filename, __file__) + self.assertEqual(r, format(h, 'n').upper()) + with self.assertWarns(DeprecationWarning) as cm: + r = format(h, '010.3N') + self.assertEqual(cm.filename, __file__) + self.assertEqual(r, format(h, '010.3n').upper()) + else: + self.assertRaises(ValueError, format, h, 'N') + self.assertRaises(ValueError, format, h, '010.3N') + with warnings_helper.check_no_warnings(self): + self.assertEqual(format(h, 'N>10.3'), 'NN6.63E-34') + self.assertEqual(format(h, 'N>10.3n'), 'NN6.63e-34') + self.assertEqual(format(h, 'N>10.3e'), 'N6.626e-34') + self.assertEqual(format(h, 'N>10.3f'), 'NNNNN0.000') + self.assertRaises(ValueError, format, h, '>Nf') + self.assertRaises(ValueError, format, h, '10Nf') + self.assertRaises(ValueError, format, h, 'Nx') + + @run_with_locale('LC_ALL', 'ps_AF', '') def test_wide_char_separator_decimal_point(self): # locale with wide char separator and decimal point Decimal = self.decimal.Decimal @@ -1911,8 +1942,6 @@ def hashit(d): x = 1100 ** 1248 self.assertEqual(hashit(Decimal(x)), hashit(x)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_hash_method_nan(self): Decimal = self.decimal.Decimal self.assertRaises(TypeError, hash, Decimal('sNaN')) @@ -2048,7 +2077,9 @@ def test_tonum_methods(self): #to quantize, which is already extensively tested test_triples = [ ('123.456', -4, '0E+4'), + ('-123.456', -4, '-0E+4'), ('123.456', -3, '0E+3'), + ('-123.456', -3, '-0E+3'), ('123.456', -2, '1E+2'), ('123.456', -1, '1.2E+2'), ('123.456', 0, '123'), @@ -2718,8 +2749,6 @@ def test_quantize(self): x = d.quantize(context=c, exp=Decimal("1e797"), rounding=ROUND_DOWN) self.assertEqual(x, Decimal('8.71E+799')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_complex(self): Decimal = self.decimal.Decimal @@ -2894,6 +2923,11 @@ class CPythonAPItests(PythonAPItests, unittest.TestCase): class PyPythonAPItests(PythonAPItests, unittest.TestCase): decimal = P + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_complex(self): # TODO(RUSTPYTHON): Remove this test when it pass + return super().test_complex() + class ContextAPItests: def test_none_args(self): @@ -3663,8 +3697,6 @@ def test_localcontextarg(self): self.assertIsNot(new_ctx, set_ctx, 'did not copy the context') self.assertIs(set_ctx, enter_ctx, '__enter__ returned wrong context') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_localcontext_kwargs(self): with self.decimal.localcontext( prec=10, rounding=ROUND_HALF_DOWN, @@ -3693,8 +3725,6 @@ def test_localcontext_kwargs(self): self.assertRaises(TypeError, self.decimal.localcontext, Emin="") self.assertRaises(TypeError, self.decimal.localcontext, Emax="") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_local_context_kwargs_does_not_overwrite_existing_argument(self): ctx = self.decimal.getcontext() orig_prec = ctx.prec @@ -4362,7 +4392,8 @@ def test_module_attributes(self): self.assertEqual(C.__version__, P.__version__) - self.assertEqual(dir(C), dir(P)) + self.assertLessEqual(set(dir(C)), set(dir(P))) + self.assertEqual([n for n in dir(C) if n[:2] != '__'], sorted(P.__all__)) def test_context_attributes(self): @@ -4438,6 +4469,15 @@ def test_implicit_context(self): self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True) # three arg power self.assertEqual(pow(Decimal(10), 2, 7), 2) + if self.decimal == C: + self.assertEqual(pow(10, Decimal(2), 7), 2) + self.assertEqual(pow(10, 2, Decimal(7)), 2) + else: + # XXX: Three-arg power doesn't use __rpow__. + self.assertRaises(TypeError, pow, 10, Decimal(2), 7) + # XXX: There is no special method to dispatch on the + # third arg of three-arg power. + self.assertRaises(TypeError, pow, 10, 2, Decimal(7)) # exp self.assertEqual(Decimal("1.01").exp(), 3) # is_normal @@ -4648,6 +4688,11 @@ def tearDown(self): sys.set_int_max_str_digits(self._previous_int_limit) super().tearDown() + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_implicit_context(self): # TODO(RUSTPYTHON): Remove this test when it pass + return super().test_implicit_context() + class PyFunctionality(unittest.TestCase): """Extra functionality in decimal.py""" @@ -4699,9 +4744,33 @@ def test_py_exact_power(self): c.prec = 1 x = Decimal("152587890625") ** Decimal('-0.5') + self.assertEqual(x, Decimal('3e-6')) + c.prec = 2 + x = Decimal("152587890625") ** Decimal('-0.5') + self.assertEqual(x, Decimal('2.6e-6')) + c.prec = 3 + x = Decimal("152587890625") ** Decimal('-0.5') + self.assertEqual(x, Decimal('2.56e-6')) + c.prec = 28 + x = Decimal("152587890625") ** Decimal('-0.5') + self.assertEqual(x, Decimal('2.56e-6')) + c.prec = 201 x = Decimal(2**578) ** Decimal("-0.5") + # See https://github.com/python/cpython/issues/118027 + # Testing for an exact power could appear to hang, in the Python + # version, as it attempted to compute 10**(MAX_EMAX + 1). + # Fixed via https://github.com/python/cpython/pull/118503. + c.prec = P.MAX_PREC + c.Emax = P.MAX_EMAX + c.Emin = P.MIN_EMIN + c.traps[P.Inexact] = 1 + D2 = Decimal(2) + # If the bug is still present, the next statement won't complete. + res = D2 ** 117 + self.assertEqual(res, 1 << 117) + def test_py_immutability_operations(self): # Do operations and check that it didn't change internal objects. Decimal = P.Decimal @@ -5625,6 +5694,25 @@ def __abs__(self): self.assertEqual(Decimal.from_float(cls(101.1)), Decimal.from_float(101.1)) + def test_c_immutable_types(self): + SignalDict = type(C.Context().flags) + SignalDictMixin = SignalDict.__bases__[0] + ContextManager = type(C.localcontext()) + types = ( + SignalDictMixin, + ContextManager, + C.Decimal, + C.Context, + ) + for tp in types: + with self.subTest(tp=tp): + with self.assertRaisesRegex(TypeError, "immutable"): + tp.foo = 1 + + def test_c_disallow_instantiation(self): + ContextManager = type(C.localcontext()) + check_disallow_instantiation(self, ContextManager) + def test_c_signaldict_segfault(self): # See gh-106263 for details. SignalDict = type(C.Context().flags) @@ -5655,6 +5743,20 @@ def test_c_signaldict_segfault(self): with self.assertRaisesRegex(ValueError, err_msg): sd.copy() + def test_format_fallback_capitals(self): + # Fallback to _pydecimal formatting (triggered by `#` format which + # is unsupported by mpdecimal) should honor the current context. + x = C.Decimal('6.09e+23') + self.assertEqual(format(x, '#'), '6.09E+23') + with C.localcontext(capitals=0): + self.assertEqual(format(x, '#'), '6.09e+23') + + def test_format_fallback_rounding(self): + y = C.Decimal('6.09') + self.assertEqual(format(y, '#.1f'), '6.1') + with C.localcontext(rounding=C.ROUND_DOWN): + self.assertEqual(format(y, '#.1f'), '6.0') + @requires_docstrings @requires_cdecimal class SignatureTest(unittest.TestCase): @@ -5818,13 +5920,17 @@ def load_tests(loader, tests, pattern): if TODO_TESTS is None: from doctest import DocTestSuite, IGNORE_EXCEPTION_DETAIL + orig_context = orig_sys_decimal.getcontext().copy() for mod in C, P: if not mod: continue def setUp(slf, mod=mod): sys.modules['decimal'] = mod - def tearDown(slf): + init(mod) + def tearDown(slf, mod=mod): sys.modules['decimal'] = orig_sys_decimal + mod.setcontext(ORIGINAL_CONTEXT[mod].copy()) + orig_sys_decimal.setcontext(orig_context.copy()) optionflags = IGNORE_EXCEPTION_DETAIL if mod is C else 0 sys.modules['decimal'] = mod tests.addTest(DocTestSuite(mod, setUp=setUp, tearDown=tearDown, @@ -5839,11 +5945,12 @@ def setUpModule(): TEST_ALL = ARITH if ARITH is not None else is_resource_enabled('decimal') def tearDownModule(): - if C: C.setcontext(ORIGINAL_CONTEXT[C]) - P.setcontext(ORIGINAL_CONTEXT[P]) + if C: C.setcontext(ORIGINAL_CONTEXT[C].copy()) + P.setcontext(ORIGINAL_CONTEXT[P].copy()) if not C: - warnings.warn('C tests skipped: no module named _decimal.', - UserWarning) + logging.getLogger(__name__).warning( + 'C tests skipped: no module named _decimal.' + ) if not orig_sys_decimal is sys.modules['decimal']: raise TestFailed("Internal error: unbalanced number of changes to " "sys.modules['decimal'].") From c232b7f1f881f4543fb8e703afc4fddbf9b9190e Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:08:01 +0200 Subject: [PATCH 3/4] Update `csv.py` from 3.13.5 (#6035) * Use `raise_if_stop!` macro * Update `csv.py` from 3.13.5 * Mark failing tests --- Lib/csv.py | 80 ++++++++++-- Lib/test/test_csv.py | 297 +++++++++++++++++++++++++++++++++---------- stdlib/src/csv.rs | 6 +- 3 files changed, 305 insertions(+), 78 deletions(-) diff --git a/Lib/csv.py b/Lib/csv.py index 77f30c8d2b..cd20265987 100644 --- a/Lib/csv.py +++ b/Lib/csv.py @@ -1,28 +1,90 @@ -""" -csv.py - read/write/investigate CSV files +r""" +CSV parsing and writing. + +This module provides classes that assist in the reading and writing +of Comma Separated Value (CSV) files, and implements the interface +described by PEP 305. Although many CSV files are simple to parse, +the format is not formally defined by a stable specification and +is subtle enough that parsing lines of a CSV file with something +like line.split(",") is bound to fail. The module supports three +basic APIs: reading, writing, and registration of dialects. + + +DIALECT REGISTRATION: + +Readers and writers support a dialect argument, which is a convenient +handle on a group of settings. When the dialect argument is a string, +it identifies one of the dialects previously registered with the module. +If it is a class or instance, the attributes of the argument are used as +the settings for the reader or writer: + + class excel: + delimiter = ',' + quotechar = '"' + escapechar = None + doublequote = True + skipinitialspace = False + lineterminator = '\r\n' + quoting = QUOTE_MINIMAL + +SETTINGS: + + * quotechar - specifies a one-character string to use as the + quoting character. It defaults to '"'. + * delimiter - specifies a one-character string to use as the + field separator. It defaults to ','. + * skipinitialspace - specifies how to interpret spaces which + immediately follow a delimiter. It defaults to False, which + means that spaces immediately following a delimiter is part + of the following field. + * lineterminator - specifies the character sequence which should + terminate rows. + * quoting - controls when quotes should be generated by the writer. + It can take on any of the following module constants: + + csv.QUOTE_MINIMAL means only when required, for example, when a + field contains either the quotechar or the delimiter + csv.QUOTE_ALL means that quotes are always placed around fields. + csv.QUOTE_NONNUMERIC means that quotes are always placed around + fields which do not parse as integers or floating-point + numbers. + csv.QUOTE_STRINGS means that quotes are always placed around + fields which are strings. Note that the Python value None + is not a string. + csv.QUOTE_NOTNULL means that quotes are only placed around fields + that are not the Python value None. + csv.QUOTE_NONE means that quotes are never placed around fields. + * escapechar - specifies a one-character string used to escape + the delimiter when quoting is set to QUOTE_NONE. + * doublequote - controls the handling of quotes inside fields. When + True, two consecutive quotes are interpreted as one during read, + and when writing, each quote character embedded in the data is + written as two quotes """ import re import types -from _csv import Error, __version__, writer, reader, register_dialect, \ +from _csv import Error, writer, reader, register_dialect, \ unregister_dialect, get_dialect, list_dialects, \ field_size_limit, \ QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, \ - QUOTE_STRINGS, QUOTE_NOTNULL, \ - __doc__ + QUOTE_STRINGS, QUOTE_NOTNULL from _csv import Dialect as _Dialect from io import StringIO __all__ = ["QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE", "QUOTE_STRINGS", "QUOTE_NOTNULL", - "Error", "Dialect", "__doc__", "excel", "excel_tab", + "Error", "Dialect", "excel", "excel_tab", "field_size_limit", "reader", "writer", "register_dialect", "get_dialect", "list_dialects", "Sniffer", - "unregister_dialect", "__version__", "DictReader", "DictWriter", + "unregister_dialect", "DictReader", "DictWriter", "unix_dialect"] +__version__ = "1.0" + + class Dialect: """Describe a CSV dialect. @@ -51,8 +113,8 @@ def _validate(self): try: _Dialect(self) except TypeError as e: - # We do this for compatibility with py2.3 - raise Error(str(e)) + # Re-raise to get a traceback showing more user code. + raise Error(str(e)) from None class excel(Dialect): """Describe the usual properties of Excel-generated CSV files.""" diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 9a1743da6d..95cf51bf08 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -10,7 +10,7 @@ import gc import pickle from test import support -from test.support import warnings_helper, import_helper, check_disallow_instantiation +from test.support import import_helper, check_disallow_instantiation from itertools import permutations from textwrap import dedent from collections import OrderedDict @@ -28,14 +28,20 @@ class Test_Csv(unittest.TestCase): in TestDialectRegistry. """ def _test_arg_valid(self, ctor, arg): + ctor(arg) self.assertRaises(TypeError, ctor) self.assertRaises(TypeError, ctor, None) - self.assertRaises(TypeError, ctor, arg, bad_attr = 0) - self.assertRaises(TypeError, ctor, arg, delimiter = 0) - self.assertRaises(TypeError, ctor, arg, delimiter = 'XX') + self.assertRaises(TypeError, ctor, arg, bad_attr=0) + self.assertRaises(TypeError, ctor, arg, delimiter='') + self.assertRaises(TypeError, ctor, arg, escapechar='') + self.assertRaises(TypeError, ctor, arg, quotechar='') + self.assertRaises(TypeError, ctor, arg, delimiter='^^') + self.assertRaises(TypeError, ctor, arg, escapechar='^^') + self.assertRaises(TypeError, ctor, arg, quotechar='^^') self.assertRaises(csv.Error, ctor, arg, 'foo') self.assertRaises(TypeError, ctor, arg, delimiter=None) self.assertRaises(TypeError, ctor, arg, delimiter=1) + self.assertRaises(TypeError, ctor, arg, escapechar=1) self.assertRaises(TypeError, ctor, arg, quotechar=1) self.assertRaises(TypeError, ctor, arg, lineterminator=None) self.assertRaises(TypeError, ctor, arg, lineterminator=1) @@ -46,11 +52,48 @@ def _test_arg_valid(self, ctor, arg): quoting=csv.QUOTE_ALL, quotechar=None) self.assertRaises(TypeError, ctor, arg, quoting=csv.QUOTE_NONE, quotechar='') + self.assertRaises(ValueError, ctor, arg, delimiter='\n') + self.assertRaises(ValueError, ctor, arg, escapechar='\n') + self.assertRaises(ValueError, ctor, arg, quotechar='\n') + self.assertRaises(ValueError, ctor, arg, delimiter='\r') + self.assertRaises(ValueError, ctor, arg, escapechar='\r') + self.assertRaises(ValueError, ctor, arg, quotechar='\r') + ctor(arg, delimiter=' ') + ctor(arg, escapechar=' ') + ctor(arg, quotechar=' ') + ctor(arg, delimiter='\t', skipinitialspace=True) + ctor(arg, escapechar='\t', skipinitialspace=True) + ctor(arg, quotechar='\t', skipinitialspace=True) + ctor(arg, delimiter=' ', skipinitialspace=True) + self.assertRaises(ValueError, ctor, arg, + escapechar=' ', skipinitialspace=True) + self.assertRaises(ValueError, ctor, arg, + quotechar=' ', skipinitialspace=True) + ctor(arg, delimiter='^') + ctor(arg, escapechar='^') + ctor(arg, quotechar='^') + self.assertRaises(ValueError, ctor, arg, delimiter='^', escapechar='^') + self.assertRaises(ValueError, ctor, arg, delimiter='^', quotechar='^') + self.assertRaises(ValueError, ctor, arg, escapechar='^', quotechar='^') + ctor(arg, delimiter='\x85') + ctor(arg, escapechar='\x85') + ctor(arg, quotechar='\x85') + ctor(arg, lineterminator='\x85') + self.assertRaises(ValueError, ctor, arg, + delimiter='\x85', lineterminator='\x85') + self.assertRaises(ValueError, ctor, arg, + escapechar='\x85', lineterminator='\x85') + self.assertRaises(ValueError, ctor, arg, + quotechar='\x85', lineterminator='\x85') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_reader_arg_valid(self): self._test_arg_valid(csv.reader, []) self.assertRaises(OSError, csv.reader, BadIterable()) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_writer_arg_valid(self): self._test_arg_valid(csv.writer, StringIO()) class BadWriter: @@ -150,13 +193,8 @@ def _write_error_test(self, exc, fields, **kwargs): fileobj.seek(0) self.assertEqual(fileobj.read(), '') - # TODO: RUSTPYTHON ''\r\n to ""\r\n unsupported - @unittest.expectedFailure def test_write_arg_valid(self): self._write_error_test(csv.Error, None) - self._write_test((), '') - self._write_test([None], '""') - self._write_error_test(csv.Error, [None], quoting = csv.QUOTE_NONE) # Check that exceptions are passed up the chain self._write_error_test(OSError, BadIterable()) class BadList: @@ -170,14 +208,13 @@ class BadItem: def __str__(self): raise OSError self._write_error_test(OSError, [BadItem()]) - def test_write_bigfield(self): # This exercises the buffer realloc functionality bigstring = 'X' * 50000 self._write_test([bigstring,bigstring], '%s,%s' % \ (bigstring, bigstring)) - # TODO: RUSTPYTHON quoting style check is unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_quoting(self): self._write_test(['a',1,'p,q'], 'a,1,"p,q"') @@ -196,7 +233,7 @@ def test_write_quoting(self): self._write_test(['a','',None,1], '"a","",,"1"', quoting = csv.QUOTE_NOTNULL) - # TODO: RUSTPYTHON doublequote check is unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_escape(self): self._write_test(['a',1,'p,q'], 'a,1,"p,q"', @@ -229,7 +266,7 @@ def test_write_escape(self): self._write_test(['C\\', '6', '7', 'X"'], 'C\\\\,6,7,"X"""', escapechar='\\', quoting=csv.QUOTE_MINIMAL) - # TODO: RUSTPYTHON lineterminator double char unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_lineterminator(self): for lineterminator in '\r\n', '\n', '\r', '!@#', '\0': @@ -238,11 +275,13 @@ def test_write_lineterminator(self): writer = csv.writer(sio, lineterminator=lineterminator) writer.writerow(['a', 'b']) writer.writerow([1, 2]) + writer.writerow(['\r', '\n']) self.assertEqual(sio.getvalue(), f'a,b{lineterminator}' - f'1,2{lineterminator}') + f'1,2{lineterminator}' + f'"\r","\n"{lineterminator}') - # TODO: RUSTPYTHON ''\r\n to ""\r\n unspported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_iterable(self): self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"') @@ -285,6 +324,53 @@ def test_writerows_with_none(self): fileobj.seek(0) self.assertEqual(fileobj.read(), 'a\r\n""\r\n') + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_write_empty_fields(self): + self._write_test((), '') + self._write_test([''], '""') + self._write_error_test(csv.Error, [''], quoting=csv.QUOTE_NONE) + self._write_test([''], '""', quoting=csv.QUOTE_STRINGS) + self._write_test([''], '""', quoting=csv.QUOTE_NOTNULL) + self._write_test([None], '""') + self._write_error_test(csv.Error, [None], quoting=csv.QUOTE_NONE) + self._write_error_test(csv.Error, [None], quoting=csv.QUOTE_STRINGS) + self._write_error_test(csv.Error, [None], quoting=csv.QUOTE_NOTNULL) + self._write_test(['', ''], ',') + self._write_test([None, None], ',') + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_write_empty_fields_space_delimiter(self): + self._write_test([''], '""', delimiter=' ', skipinitialspace=False) + self._write_test([''], '""', delimiter=' ', skipinitialspace=True) + self._write_test([None], '""', delimiter=' ', skipinitialspace=False) + self._write_test([None], '""', delimiter=' ', skipinitialspace=True) + + self._write_test(['', ''], ' ', delimiter=' ', skipinitialspace=False) + self._write_test(['', ''], '"" ""', delimiter=' ', skipinitialspace=True) + self._write_test([None, None], ' ', delimiter=' ', skipinitialspace=False) + self._write_test([None, None], '"" ""', delimiter=' ', skipinitialspace=True) + + self._write_test(['', ''], ' ', delimiter=' ', skipinitialspace=False, + quoting=csv.QUOTE_NONE) + self._write_error_test(csv.Error, ['', ''], + delimiter=' ', skipinitialspace=True, + quoting=csv.QUOTE_NONE) + for quoting in csv.QUOTE_STRINGS, csv.QUOTE_NOTNULL: + self._write_test(['', ''], '"" ""', delimiter=' ', skipinitialspace=False, + quoting=quoting) + self._write_test(['', ''], '"" ""', delimiter=' ', skipinitialspace=True, + quoting=quoting) + + for quoting in csv.QUOTE_NONE, csv.QUOTE_STRINGS, csv.QUOTE_NOTNULL: + self._write_test([None, None], ' ', delimiter=' ', skipinitialspace=False, + quoting=quoting) + self._write_error_test(csv.Error, [None, None], + delimiter=' ', skipinitialspace=True, + quoting=quoting) + def test_writerows_errors(self): with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: writer = csv.writer(fileobj) @@ -296,7 +382,7 @@ def _read_test(self, input, expect, **kwargs): result = list(reader) self.assertEqual(result, expect) - # TODO RUSTPYTHON strict mode is unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_oddinputs(self): self._read_test([], []) @@ -308,16 +394,23 @@ def test_read_oddinputs(self): self.assertRaises(csv.Error, self._read_test, [b'abc'], None) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_read_eol(self): - self._read_test(['a,b'], [['a','b']]) - self._read_test(['a,b\n'], [['a','b']]) - self._read_test(['a,b\r\n'], [['a','b']]) - self._read_test(['a,b\r'], [['a','b']]) - self.assertRaises(csv.Error, self._read_test, ['a,b\rc,d'], []) - self.assertRaises(csv.Error, self._read_test, ['a,b\nc,d'], []) - self.assertRaises(csv.Error, self._read_test, ['a,b\r\nc,d'], []) - - # TODO RUSTPYTHON double quote umimplement + self._read_test(['a,b', 'c,d'], [['a','b'], ['c','d']]) + self._read_test(['a,b\n', 'c,d\n'], [['a','b'], ['c','d']]) + self._read_test(['a,b\r\n', 'c,d\r\n'], [['a','b'], ['c','d']]) + self._read_test(['a,b\r', 'c,d\r'], [['a','b'], ['c','d']]) + + errmsg = "with newline=''" + with self.assertRaisesRegex(csv.Error, errmsg): + next(csv.reader(['a,b\rc,d'])) + with self.assertRaisesRegex(csv.Error, errmsg): + next(csv.reader(['a,b\nc,d'])) + with self.assertRaisesRegex(csv.Error, errmsg): + next(csv.reader(['a,b\r\nc,d'])) + + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_eof(self): self._read_test(['a,"'], [['a', '']]) @@ -328,7 +421,7 @@ def test_read_eof(self): self.assertRaises(csv.Error, self._read_test, ['^'], [], escapechar='^', strict=True) - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_nul(self): self._read_test(['\0'], [['\0']]) @@ -342,7 +435,7 @@ def test_read_delimiter(self): self._read_test(['a;b;c'], [['a', 'b', 'c']], delimiter=';') self._read_test(['a\0b\0c'], [['a', 'b', 'c']], delimiter='\0') - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_escape(self): self._read_test(['a,\\b,c'], [['a', 'b', 'c']], escapechar='\\') @@ -356,7 +449,7 @@ def test_read_escape(self): self._read_test(['a,\\b,c'], [['a', '\\b', 'c']], escapechar=None) self._read_test(['a,\\b,c'], [['a', '\\b', 'c']]) - # TODO RUSTPYTHON escapechar unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_quoting(self): self._read_test(['1,",3,",5'], [['1', ',3,', '5']]) @@ -367,17 +460,58 @@ def test_read_quoting(self): # will this fail where locale uses comma for decimals? self._read_test([',3,"5",7.3, 9'], [['', 3, '5', 7.3, 9]], quoting=csv.QUOTE_NONNUMERIC) + self._read_test([',3,"5",7.3, 9'], [[None, '3', '5', '7.3', ' 9']], + quoting=csv.QUOTE_NOTNULL) + self._read_test([',3,"5",7.3, 9'], [[None, 3, '5', 7.3, 9]], + quoting=csv.QUOTE_STRINGS) + + self._read_test([',,"",'], [['', '', '', '']]) + self._read_test([',,"",'], [['', '', '', '']], + quoting=csv.QUOTE_NONNUMERIC) + self._read_test([',,"",'], [[None, None, '', None]], + quoting=csv.QUOTE_NOTNULL) + self._read_test([',,"",'], [[None, None, '', None]], + quoting=csv.QUOTE_STRINGS) + self._read_test(['"a\nb", 7'], [['a\nb', ' 7']]) self.assertRaises(ValueError, self._read_test, ['abc,3'], [[]], quoting=csv.QUOTE_NONNUMERIC) + self.assertRaises(ValueError, self._read_test, + ['abc,3'], [[]], + quoting=csv.QUOTE_STRINGS) self._read_test(['1,@,3,@,5'], [['1', ',3,', '5']], quotechar='@') self._read_test(['1,\0,3,\0,5'], [['1', ',3,', '5']], quotechar='\0') + self._read_test(['1\\.5,\\.5,.5'], [[1.5, 0.5, 0.5]], + quoting=csv.QUOTE_NONNUMERIC, escapechar='\\') + self._read_test(['1\\.5,\\.5,"\\.5"'], [[1.5, 0.5, ".5"]], + quoting=csv.QUOTE_STRINGS, escapechar='\\') + # TODO: RUSTPYTHON; panic + @unittest.skip("TODO: RUSTPYTHON; slice index starts at 1 but ends at 0") def test_read_skipinitialspace(self): self._read_test(['no space, space, spaces,\ttab'], [['no space', 'space', 'spaces', '\ttab']], skipinitialspace=True) + self._read_test([' , , '], + [['', '', '']], + skipinitialspace=True) + self._read_test([' , , '], + [[None, None, None]], + skipinitialspace=True, quoting=csv.QUOTE_NOTNULL) + self._read_test([' , , '], + [[None, None, None]], + skipinitialspace=True, quoting=csv.QUOTE_STRINGS) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_read_space_delimiter(self): + self._read_test(['a b', ' a ', ' ', ''], + [['a', '', '', 'b'], ['', '', 'a', '', ''], ['', '', ''], []], + delimiter=' ', skipinitialspace=False) + self._read_test(['a b', ' a ', ' ', ''], + [['a', 'b'], ['a', ''], [''], []], + delimiter=' ', skipinitialspace=True) def test_read_bigfield(self): # This exercises the buffer realloc functionality and field size @@ -410,27 +544,49 @@ def test_read_linenum(self): self.assertRaises(StopIteration, next, r) self.assertEqual(r.line_num, 3) - # TODO: RUSTPYTHON only '\r\n' unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_roundtrip_quoteed_newlines(self): - with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: - writer = csv.writer(fileobj) - rows = [['a\nb','b'],['c','x\r\nd']] - writer.writerows(rows) - fileobj.seek(0) - for i, row in enumerate(csv.reader(fileobj)): - self.assertEqual(row, rows[i]) + rows = [ + ['\na', 'b\nc', 'd\n'], + ['\re', 'f\rg', 'h\r'], + ['\r\ni', 'j\r\nk', 'l\r\n'], + ['\n\rm', 'n\n\ro', 'p\n\r'], + ['\r\rq', 'r\r\rs', 't\r\r'], + ['\n\nu', 'v\n\nw', 'x\n\n'], + ] + for lineterminator in '\r\n', '\n', '\r': + with self.subTest(lineterminator=lineterminator): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, lineterminator=lineterminator) + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj)): + self.assertEqual(row, rows[i]) - # TODO: RUSTPYTHON only '\r\n' unsupported + # TODO: RUSTPYTHON @unittest.expectedFailure def test_roundtrip_escaped_unquoted_newlines(self): - with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: - writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\") - rows = [['a\nb','b'],['c','x\r\nd']] - writer.writerows(rows) - fileobj.seek(0) - for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")): - self.assertEqual(row,rows[i]) + rows = [ + ['\na', 'b\nc', 'd\n'], + ['\re', 'f\rg', 'h\r'], + ['\r\ni', 'j\r\nk', 'l\r\n'], + ['\n\rm', 'n\n\ro', 'p\n\r'], + ['\r\rq', 'r\r\rs', 't\r\r'], + ['\n\nu', 'v\n\nw', 'x\n\n'], + ] + for lineterminator in '\r\n', '\n', '\r': + with self.subTest(lineterminator=lineterminator): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, lineterminator=lineterminator, + quoting=csv.QUOTE_NONE, escapechar="\\") + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj, + quoting=csv.QUOTE_NONE, + escapechar="\\")): + self.assertEqual(row, rows[i]) + class TestDialectRegistry(unittest.TestCase): def test_registry_badargs(self): @@ -509,10 +665,10 @@ class space(csv.excel): escapechar = "\\" with TemporaryFile("w+", encoding="utf-8") as fileobj: - fileobj.write("abc def\nc1ccccc1 benzene\n") + fileobj.write("abc def\nc1ccccc1 benzene\n") fileobj.seek(0) reader = csv.reader(fileobj, dialect=space()) - self.assertEqual(next(reader), ["abc", "def"]) + self.assertEqual(next(reader), ["abc", "", "", "def"]) self.assertEqual(next(reader), ["c1ccccc1", "benzene"]) def compare_dialect_123(self, expected, *writeargs, **kwwriteargs): @@ -556,14 +712,6 @@ class unspecified(): finally: csv.unregister_dialect('testC') - def test_bad_dialect(self): - # Unknown parameter - self.assertRaises(TypeError, csv.reader, [], bad_attr = 0) - # Bad values - self.assertRaises(TypeError, csv.reader, [], delimiter = None) - self.assertRaises(TypeError, csv.reader, [], quoting = -1) - self.assertRaises(TypeError, csv.reader, [], quoting = 100) - def test_copy(self): for name in csv.list_dialects(): dialect = csv.get_dialect(name) @@ -657,7 +805,7 @@ def test_quoted_quote(self): '"I see," said the blind man', 'as he picked up his hammer and saw']]) - # Rustpython TODO + # TODO: RUSTPYTHON @unittest.expectedFailure def test_quoted_nl(self): input = '''\ @@ -699,12 +847,12 @@ class EscapedExcel(csv.excel): class TestEscapedExcel(TestCsvBase): dialect = EscapedExcel() - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_escape_fieldsep(self): self.writerAssertEqual([['abc,def']], 'abc\\,def\r\n') - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_escape_fieldsep(self): self.readerAssertEqual('abc\\,def\r\n', [['abc,def']]) @@ -712,7 +860,7 @@ def test_read_escape_fieldsep(self): class TestDialectUnix(TestCsvBase): dialect = 'unix' - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_simple_writer(self): self.writerAssertEqual([[1, 'abc def', 'abc']], '"1","abc def","abc"\n') @@ -730,7 +878,7 @@ class TestQuotedEscapedExcel(TestCsvBase): def test_write_escape_fieldsep(self): self.writerAssertEqual([['abc,def']], '"abc,def"\r\n') - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_escape_fieldsep(self): self.readerAssertEqual('"abc\\,def"\r\n', [['abc,def']]) @@ -928,7 +1076,7 @@ def test_read_multi(self): "s1": 'abc', "s2": 'def'}) - # TODO RUSTPYTHON + # TODO: RUSTPYTHON @unittest.expectedFailure def test_read_with_blanks(self): reader = csv.DictReader(["1,2,abc,4,5,6\r\n","\r\n", @@ -981,9 +1129,11 @@ def test_float_write(self): fileobj.seek(0) self.assertEqual(fileobj.read(), expected) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_char_write(self): import array, string - a = array.array('u', string.ascii_letters) + a = array.array('w', string.ascii_letters) with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: writer = csv.writer(fileobj, dialect="excel") @@ -1007,6 +1157,12 @@ class mydialect(csv.Dialect): mydialect.quoting = None self.assertRaises(csv.Error, mydialect) + mydialect.quoting = 42 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + 'bad "quoting" value') + mydialect.doublequote = True mydialect.quoting = csv.QUOTE_ALL mydialect.quotechar = '"' @@ -1122,11 +1278,18 @@ class mydialect(csv.Dialect): self.assertEqual(str(cm.exception), '"lineterminator" must be a string') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_invalid_chars(self): - def create_invalid(field_name, value): + def create_invalid(field_name, value, **kwargs): class mydialect(csv.Dialect): - pass + delimiter = ',' + quoting = csv.QUOTE_ALL + quotechar = '"' + lineterminator = '\r\n' setattr(mydialect, field_name, value) + for field_name, value in kwargs.items(): + setattr(mydialect, field_name, value) d = mydialect() for field_name in ("delimiter", "escapechar", "quotechar"): @@ -1135,6 +1298,11 @@ class mydialect(csv.Dialect): self.assertRaises(csv.Error, create_invalid, field_name, "abc") self.assertRaises(csv.Error, create_invalid, field_name, b'x') self.assertRaises(csv.Error, create_invalid, field_name, 5) + self.assertRaises(ValueError, create_invalid, field_name, "\n") + self.assertRaises(ValueError, create_invalid, field_name, "\r") + if field_name != "delimiter": + self.assertRaises(ValueError, create_invalid, field_name, " ", + skipinitialspace=True) class TestSniffer(unittest.TestCase): @@ -1451,8 +1619,7 @@ def test_ordered_dict_reader(self): class MiscTestCase(unittest.TestCase): def test__all__(self): - extra = {'__doc__', '__version__'} - support.check__all__(self, csv, ('csv', '_csv'), extra=extra) + support.check__all__(self, csv, ('csv', '_csv')) def test_subclassable(self): # issue 44089 diff --git a/stdlib/src/csv.rs b/stdlib/src/csv.rs index bded06e37f..5f6d249476 100644 --- a/stdlib/src/csv.rs +++ b/stdlib/src/csv.rs @@ -8,6 +8,7 @@ mod _csv { builtins::{PyBaseExceptionRef, PyInt, PyNone, PyStr, PyType, PyTypeError, PyTypeRef}, function::{ArgIterable, ArgumentError, FromArgs, FuncArgs, OptionalArg}, protocol::{PyIter, PyIterReturn}, + raise_if_stop, types::{Constructor, IterNext, Iterable, SelfIter}, }; use csv_core::Terminator; @@ -923,10 +924,7 @@ mod _csv { impl SelfIter for Reader {} impl IterNext for Reader { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { - let string = match zelf.iter.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let string = raise_if_stop!(zelf.iter.next(vm)?); let string = string.downcast::().map_err(|obj| { new_csv_error( vm, From 4079776c36728deea83e49f609d2443063ef187c Mon Sep 17 00:00:00 2001 From: Jack O'Connor Date: Fri, 25 Jul 2025 16:42:22 +0100 Subject: [PATCH 4/4] Add `kde` function and tests to RustPython statistics module (#6030) * Copy CPython 3.13 statistics module into RustPython * Adjust CPython "magic constants" in KDE tests ## test_kde I'm not too sure why but this one takes a few seconds to run the second for loop which calculates the cumulative distribution and does a rough calculation of the area under the curve. ## test_kde_random I have a lower bound for RustPython to sort a random list of 1_000_000 numbers on my laptop of > 1 hour. By dropping n to 30_000 sort will not take an egregious amount of time to run. It is then necessary to lower the tolerance for the math.isclose check, or the computed values may **randomly** fail due to the higher variance caused by the smaller sample size. * Reintroduce expected failure in test_statistics.TestNormalDict.test_slots * Sync Rust `normal_dist_inv_cdf` with Python equivalent See https://github.com/python/cpython/pull/95265. To quote: > Restores alignment with random.gauss(mu, sigma) and random.normalvariate(mu, sigma) both. of which are equivalent to sampling from NormalDist(mu, sigma).inv_cdf(random()). The two functions in the random module happy accept sigma=0 and give a well-defined result. > This also lets the function gently handle a sigma getting smaller, eventually becoming zero. As sigma decrease, NormalDist(mu, sigma).inv_cdf(p) forms a tighter and tighter internal around mu and becoming exactly mu in the limit. For example, NormalDist(100, 1E-300).inv_cdf(0.3) cleanly evaluates to 100.0but withsigma=1e-500`` the function previously would raised an unexpected error. --- Lib/statistics.py | 910 ++++++++++++++++++++++++++++-------- Lib/test/test_statistics.py | 533 ++++++++++++++++++--- stdlib/src/statistics.rs | 2 +- 3 files changed, 1193 insertions(+), 252 deletions(-) diff --git a/Lib/statistics.py b/Lib/statistics.py index f66245380a..ad4a94219c 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -11,7 +11,7 @@ Function Description ================== ================================================== mean Arithmetic mean (average) of data. -fmean Fast, floating point arithmetic mean. +fmean Fast, floating-point arithmetic mean. geometric_mean Geometric mean of data. harmonic_mean Harmonic mean of data. median Median (middle value) of data. @@ -112,6 +112,8 @@ 'fmean', 'geometric_mean', 'harmonic_mean', + 'kde', + 'kde_random', 'linear_regression', 'mean', 'median', @@ -130,14 +132,20 @@ import math import numbers import random +import sys from fractions import Fraction from decimal import Decimal -from itertools import groupby, repeat +from itertools import count, groupby, repeat from bisect import bisect_left, bisect_right -from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum +from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod +from math import isfinite, isinf, pi, cos, sin, tan, cosh, asin, atan, acos +from functools import reduce from operator import itemgetter -from collections import Counter, namedtuple +from collections import Counter, namedtuple, defaultdict + +_SQRT2 = sqrt(2.0) +_random = random # === Exceptions === @@ -180,11 +188,12 @@ def _sum(data): allowed. """ count = 0 + types = set() + types_add = types.add partials = {} partials_get = partials.get - T = int for typ, values in groupby(data, type): - T = _coerce(T, typ) # or raise TypeError + types_add(typ) for n, d in map(_exact_ratio, values): count += 1 partials[d] = partials_get(d, 0) + n @@ -196,9 +205,51 @@ def _sum(data): else: # Sum all the partial sums using builtin sum. total = sum(Fraction(n, d) for d, n in partials.items()) + T = reduce(_coerce, types, int) # or raise TypeError return (T, total, count) +def _ss(data, c=None): + """Return the exact mean and sum of square deviations of sequence data. + + Calculations are done in a single pass, allowing the input to be an iterator. + + If given *c* is used the mean; otherwise, it is calculated from the data. + Use the *c* argument with care, as it can lead to garbage results. + + """ + if c is not None: + T, ssd, count = _sum((d := x - c) * d for x in data) + return (T, ssd, c, count) + count = 0 + types = set() + types_add = types.add + sx_partials = defaultdict(int) + sxx_partials = defaultdict(int) + for typ, values in groupby(data, type): + types_add(typ) + for n, d in map(_exact_ratio, values): + count += 1 + sx_partials[d] += n + sxx_partials[d] += n * n + if not count: + ssd = c = Fraction(0) + elif None in sx_partials: + # The sum will be a NAN or INF. We can ignore all the finite + # partials, and just look at this special one. + ssd = c = sx_partials[None] + assert not _isfinite(ssd) + else: + sx = sum(Fraction(n, d) for d, n in sx_partials.items()) + sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items()) + # This formula has poor numeric properties for floats, + # but with fractions it is exact. + ssd = (count * sxx - sx * sx) / count + c = sx / count + T = reduce(_coerce, types, int) # or raise TypeError + return (T, ssd, c, count) + + def _isfinite(x): try: return x.is_finite() # Likely a Decimal. @@ -245,6 +296,28 @@ def _exact_ratio(x): x is expected to be an int, Fraction, Decimal or float. """ + + # XXX We should revisit whether using fractions to accumulate exact + # ratios is the right way to go. + + # The integer ratios for binary floats can have numerators or + # denominators with over 300 decimal digits. The problem is more + # acute with decimal floats where the default decimal context + # supports a huge range of exponents from Emin=-999999 to + # Emax=999999. When expanded with as_integer_ratio(), numbers like + # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large + # numerators or denominators that will slow computation. + + # When the integer ratios are accumulated as fractions, the size + # grows to cover the full range from the smallest magnitude to the + # largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300), + # has a 616 digit numerator. Likewise, + # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000')) + # has 10,003 digit numerator. + + # This doesn't seem to have been problem in practice, but it is a + # potential pitfall. + try: return x.as_integer_ratio() except AttributeError: @@ -279,22 +352,6 @@ def _convert(value, T): raise -def _find_lteq(a, x): - 'Locate the leftmost value exactly equal to x' - i = bisect_left(a, x) - if i != len(a) and a[i] == x: - return i - raise ValueError - - -def _find_rteq(a, l, x): - 'Locate the rightmost value exactly equal to x' - i = bisect_right(a, x, lo=l) - if i != (len(a) + 1) and a[i - 1] == x: - return i - 1 - raise ValueError - - def _fail_neg(values, errmsg='negative value'): """Iterate over values, failing if any are less than zero.""" for x in values: @@ -303,6 +360,113 @@ def _fail_neg(values, errmsg='negative value'): yield x +def _rank(data, /, *, key=None, reverse=False, ties='average', start=1) -> list[float]: + """Rank order a dataset. The lowest value has rank 1. + + Ties are averaged so that equal values receive the same rank: + + >>> data = [31, 56, 31, 25, 75, 18] + >>> _rank(data) + [3.5, 5.0, 3.5, 2.0, 6.0, 1.0] + + The operation is idempotent: + + >>> _rank([3.5, 5.0, 3.5, 2.0, 6.0, 1.0]) + [3.5, 5.0, 3.5, 2.0, 6.0, 1.0] + + It is possible to rank the data in reverse order so that the + highest value has rank 1. Also, a key-function can extract + the field to be ranked: + + >>> goals = [('eagles', 45), ('bears', 48), ('lions', 44)] + >>> _rank(goals, key=itemgetter(1), reverse=True) + [2.0, 1.0, 3.0] + + Ranks are conventionally numbered starting from one; however, + setting *start* to zero allows the ranks to be used as array indices: + + >>> prize = ['Gold', 'Silver', 'Bronze', 'Certificate'] + >>> scores = [8.1, 7.3, 9.4, 8.3] + >>> [prize[int(i)] for i in _rank(scores, start=0, reverse=True)] + ['Bronze', 'Certificate', 'Gold', 'Silver'] + + """ + # If this function becomes public at some point, more thought + # needs to be given to the signature. A list of ints is + # plausible when ties is "min" or "max". When ties is "average", + # either list[float] or list[Fraction] is plausible. + + # Default handling of ties matches scipy.stats.mstats.spearmanr. + if ties != 'average': + raise ValueError(f'Unknown tie resolution method: {ties!r}') + if key is not None: + data = map(key, data) + val_pos = sorted(zip(data, count()), reverse=reverse) + i = start - 1 + result = [0] * len(val_pos) + for _, g in groupby(val_pos, key=itemgetter(0)): + group = list(g) + size = len(group) + rank = i + (size + 1) / 2 + for value, orig_pos in group: + result[orig_pos] = rank + i += size + return result + + +def _integer_sqrt_of_frac_rto(n: int, m: int) -> int: + """Square root of n/m, rounded to the nearest integer using round-to-odd.""" + # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf + a = math.isqrt(n // m) + return a | (a*a*m != n) + + +# For 53 bit precision floats, the bit width used in +# _float_sqrt_of_frac() is 109. +_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3 + + +def _float_sqrt_of_frac(n: int, m: int) -> float: + """Square root of n/m as a float, correctly rounded.""" + # See principle and proof sketch at: https://bugs.python.org/msg407078 + q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2 + if q >= 0: + numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q + denominator = 1 + else: + numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m) + denominator = 1 << -q + return numerator / denominator # Convert to float + + +def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal: + """Square root of n/m as a Decimal, correctly rounded.""" + # Premise: For decimal, computing (n/m).sqrt() can be off + # by 1 ulp from the correctly rounded result. + # Method: Check the result, moving up or down a step if needed. + if n <= 0: + if not n: + return Decimal('0.0') + n, m = -n, -m + + root = (Decimal(n) / Decimal(m)).sqrt() + nr, dr = root.as_integer_ratio() + + plus = root.next_plus() + np, dp = plus.as_integer_ratio() + # test: n / m > ((root + plus) / 2) ** 2 + if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2: + return plus + + minus = root.next_minus() + nm, dm = minus.as_integer_ratio() + # test: n / m < ((root + minus) / 2) ** 2 + if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2: + return minus + + return root + + # === Measures of central tendency (averages) === def mean(data): @@ -321,17 +485,13 @@ def mean(data): If ``data`` is empty, StatisticsError will be raised. """ - if iter(data) is data: - data = list(data) - n = len(data) + T, total, n = _sum(data) if n < 1: raise StatisticsError('mean requires at least one data point') - T, total, count = _sum(data) - assert count == n return _convert(total / n, T) -def fmean(data): +def fmean(data, weights=None): """Convert data to floats and compute the arithmetic mean. This runs faster than the mean() function and it always returns a float. @@ -340,29 +500,40 @@ def fmean(data): >>> fmean([3.5, 4.0, 5.25]) 4.25 """ - try: - n = len(data) - except TypeError: - # Handle iterators that do not define __len__(). - n = 0 - def count(iterable): - nonlocal n - for n, x in enumerate(iterable, start=1): - yield x - total = fsum(count(data)) - else: + if weights is None: + try: + n = len(data) + except TypeError: + # Handle iterators that do not define __len__(). + n = 0 + def count(iterable): + nonlocal n + for n, x in enumerate(iterable, start=1): + yield x + data = count(data) total = fsum(data) - try: + if not n: + raise StatisticsError('fmean requires at least one data point') return total / n - except ZeroDivisionError: - raise StatisticsError('fmean requires at least one data point') from None + if not isinstance(weights, (list, tuple)): + weights = list(weights) + try: + num = sumprod(data, weights) + except ValueError: + raise StatisticsError('data and weights must be the same length') + den = fsum(weights) + if not den: + raise StatisticsError('sum of weights must be non-zero') + return num / den def geometric_mean(data): """Convert data to floats and compute the geometric mean. - Raises a StatisticsError if the input dataset is empty, - if it contains a zero, or if it contains a negative value. + Raises a StatisticsError if the input dataset is empty + or if it contains a negative value. + + Returns zero if the product of inputs is zero. No special efforts are made to achieve exact results. (However, this may change in the future.) @@ -370,11 +541,25 @@ def geometric_mean(data): >>> round(geometric_mean([54, 24, 36]), 9) 36.0 """ - try: - return exp(fmean(map(log, data))) - except ValueError: - raise StatisticsError('geometric mean requires a non-empty dataset ' - 'containing positive numbers') from None + n = 0 + found_zero = False + def count_positive(iterable): + nonlocal n, found_zero + for n, x in enumerate(iterable, start=1): + if x > 0.0 or math.isnan(x): + yield x + elif x == 0.0: + found_zero = True + else: + raise StatisticsError('No negative inputs allowed', x) + total = fsum(map(log, count_positive(data))) + if not n: + raise StatisticsError('Must have a non-empty dataset') + if math.isnan(total): + return math.nan + if found_zero: + return math.nan if total == math.inf else 0.0 + return exp(total / n) def harmonic_mean(data, weights=None): @@ -498,58 +683,75 @@ def median_high(data): return data[n // 2] -def median_grouped(data, interval=1): - """Return the 50th percentile (median) of grouped continuous data. +def median_grouped(data, interval=1.0): + """Estimates the median for numeric data binned around the midpoints + of consecutive, fixed-width intervals. - >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5]) - 3.7 - >>> median_grouped([52, 52, 53, 54]) - 52.5 + The *data* can be any iterable of numeric data with each value being + exactly the midpoint of a bin. At least one value must be present. - This calculates the median as the 50th percentile, and should be - used when your data is continuous and grouped. In the above example, - the values 1, 2, 3, etc. actually represent the midpoint of classes - 0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in - class 3.5-4.5, and interpolation is used to estimate it. + The *interval* is width of each bin. - Optional argument ``interval`` represents the class interval, and - defaults to 1. Changing the class interval naturally will change the - interpolated 50th percentile value: + For example, demographic information may have been summarized into + consecutive ten-year age groups with each group being represented + by the 5-year midpoints of the intervals: - >>> median_grouped([1, 3, 3, 5, 7], interval=1) - 3.25 - >>> median_grouped([1, 3, 3, 5, 7], interval=2) - 3.5 + >>> demographics = Counter({ + ... 25: 172, # 20 to 30 years old + ... 35: 484, # 30 to 40 years old + ... 45: 387, # 40 to 50 years old + ... 55: 22, # 50 to 60 years old + ... 65: 6, # 60 to 70 years old + ... }) + + The 50th percentile (median) is the 536th person out of the 1071 + member cohort. That person is in the 30 to 40 year old age group. + + The regular median() function would assume that everyone in the + tricenarian age group was exactly 35 years old. A more tenable + assumption is that the 484 members of that age group are evenly + distributed between 30 and 40. For that, we use median_grouped(). + + >>> data = list(demographics.elements()) + >>> median(data) + 35 + >>> round(median_grouped(data, interval=10), 1) + 37.5 + + The caller is responsible for making sure the data points are separated + by exact multiples of *interval*. This is essential for getting a + correct result. The function does not check this precondition. + + Inputs may be any numeric type that can be coerced to a float during + the interpolation step. - This function does not check whether the data points are at least - ``interval`` apart. """ data = sorted(data) n = len(data) - if n == 0: + if not n: raise StatisticsError("no median for empty data") - elif n == 1: - return data[0] + # Find the value at the midpoint. Remember this corresponds to the - # centre of the class interval. + # midpoint of the class interval. x = data[n // 2] - for obj in (x, interval): - if isinstance(obj, (str, bytes)): - raise TypeError('expected number but got %r' % obj) + + # Using O(log n) bisection, find where all the x values occur in the data. + # All x will lie within data[i:j]. + i = bisect_left(data, x) + j = bisect_right(data, x, lo=i) + + # Coerce to floats, raising a TypeError if not possible try: - L = x - interval / 2 # The lower limit of the median interval. - except TypeError: - # Mixed type. For now we just coerce to float. - L = float(x) - float(interval) / 2 - - # Uses bisection search to search for x in data with log(n) time complexity - # Find the position of leftmost occurrence of x in data - l1 = _find_lteq(data, x) - # Find the position of rightmost occurrence of x in data[l1...len(data)] - # Assuming always l1 <= l2 - l2 = _find_rteq(data, l1, x) - cf = l1 - f = l2 - l1 + 1 + interval = float(interval) + x = float(x) + except ValueError: + raise TypeError(f'Value cannot be converted to a float') + + # Interpolate the median using the formula found at: + # https://www.cuemath.com/data/median-of-grouped-data/ + L = x - interval / 2.0 # Lower limit of the median interval + cf = i # Cumulative frequency of the preceding interval + f = j - i # Number of elements in the median internal return L + interval * (n / 2 - cf) / f @@ -596,9 +798,223 @@ def multimode(data): >>> multimode('') [] """ - counts = Counter(iter(data)).most_common() - maxcount, mode_items = next(groupby(counts, key=itemgetter(1)), (0, [])) - return list(map(itemgetter(0), mode_items)) + counts = Counter(iter(data)) + if not counts: + return [] + maxcount = max(counts.values()) + return [value for value, count in counts.items() if count == maxcount] + + +def kde(data, h, kernel='normal', *, cumulative=False): + """Kernel Density Estimation: Create a continuous probability density + function or cumulative distribution function from discrete samples. + + The basic idea is to smooth the data using a kernel function + to help draw inferences about a population from a sample. + + The degree of smoothing is controlled by the scaling parameter h + which is called the bandwidth. Smaller values emphasize local + features while larger values give smoother results. + + The kernel determines the relative weights of the sample data + points. Generally, the choice of kernel shape does not matter + as much as the more influential bandwidth smoothing parameter. + + Kernels that give some weight to every sample point: + + normal (gauss) + logistic + sigmoid + + Kernels that only give weight to sample points within + the bandwidth: + + rectangular (uniform) + triangular + parabolic (epanechnikov) + quartic (biweight) + triweight + cosine + + If *cumulative* is true, will return a cumulative distribution function. + + A StatisticsError will be raised if the data sequence is empty. + + Example + ------- + + Given a sample of six data points, construct a continuous + function that estimates the underlying probability density: + + >>> sample = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2] + >>> f_hat = kde(sample, h=1.5) + + Compute the area under the curve: + + >>> area = sum(f_hat(x) for x in range(-20, 20)) + >>> round(area, 4) + 1.0 + + Plot the estimated probability density function at + evenly spaced points from -6 to 10: + + >>> for x in range(-6, 11): + ... density = f_hat(x) + ... plot = ' ' * int(density * 400) + 'x' + ... print(f'{x:2}: {density:.3f} {plot}') + ... + -6: 0.002 x + -5: 0.009 x + -4: 0.031 x + -3: 0.070 x + -2: 0.111 x + -1: 0.125 x + 0: 0.110 x + 1: 0.086 x + 2: 0.068 x + 3: 0.059 x + 4: 0.066 x + 5: 0.082 x + 6: 0.082 x + 7: 0.058 x + 8: 0.028 x + 9: 0.009 x + 10: 0.002 x + + Estimate P(4.5 < X <= 7.5), the probability that a new sample value + will be between 4.5 and 7.5: + + >>> cdf = kde(sample, h=1.5, cumulative=True) + >>> round(cdf(7.5) - cdf(4.5), 2) + 0.22 + + References + ---------- + + Kernel density estimation and its application: + https://www.itm-conferences.org/articles/itmconf/pdf/2018/08/itmconf_sam2018_00037.pdf + + Kernel functions in common use: + https://en.wikipedia.org/wiki/Kernel_(statistics)#kernel_functions_in_common_use + + Interactive graphical demonstration and exploration: + https://demonstrations.wolfram.com/KernelDensityEstimation/ + + Kernel estimation of cumulative distribution function of a random variable with bounded support + https://www.econstor.eu/bitstream/10419/207829/1/10.21307_stattrans-2016-037.pdf + + """ + + n = len(data) + if not n: + raise StatisticsError('Empty data sequence') + + if not isinstance(data[0], (int, float)): + raise TypeError('Data sequence must contain ints or floats') + + if h <= 0.0: + raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}') + + match kernel: + + case 'normal' | 'gauss': + sqrt2pi = sqrt(2 * pi) + sqrt2 = sqrt(2) + K = lambda t: exp(-1/2 * t * t) / sqrt2pi + W = lambda t: 1/2 * (1.0 + erf(t / sqrt2)) + support = None + + case 'logistic': + # 1.0 / (exp(t) + 2.0 + exp(-t)) + K = lambda t: 1/2 / (1.0 + cosh(t)) + W = lambda t: 1.0 - 1.0 / (exp(t) + 1.0) + support = None + + case 'sigmoid': + # (2/pi) / (exp(t) + exp(-t)) + c1 = 1 / pi + c2 = 2 / pi + K = lambda t: c1 / cosh(t) + W = lambda t: c2 * atan(exp(t)) + support = None + + case 'rectangular' | 'uniform': + K = lambda t: 1/2 + W = lambda t: 1/2 * t + 1/2 + support = 1.0 + + case 'triangular': + K = lambda t: 1.0 - abs(t) + W = lambda t: t*t * (1/2 if t < 0.0 else -1/2) + t + 1/2 + support = 1.0 + + case 'parabolic' | 'epanechnikov': + K = lambda t: 3/4 * (1.0 - t * t) + W = lambda t: -1/4 * t**3 + 3/4 * t + 1/2 + support = 1.0 + + case 'quartic' | 'biweight': + K = lambda t: 15/16 * (1.0 - t * t) ** 2 + W = lambda t: 3/16 * t**5 - 5/8 * t**3 + 15/16 * t + 1/2 + support = 1.0 + + case 'triweight': + K = lambda t: 35/32 * (1.0 - t * t) ** 3 + W = lambda t: 35/32 * (-1/7*t**7 + 3/5*t**5 - t**3 + t) + 1/2 + support = 1.0 + + case 'cosine': + c1 = pi / 4 + c2 = pi / 2 + K = lambda t: c1 * cos(c2 * t) + W = lambda t: 1/2 * sin(c2 * t) + 1/2 + support = 1.0 + + case _: + raise StatisticsError(f'Unknown kernel name: {kernel!r}') + + if support is None: + + def pdf(x): + n = len(data) + return sum(K((x - x_i) / h) for x_i in data) / (n * h) + + def cdf(x): + n = len(data) + return sum(W((x - x_i) / h) for x_i in data) / n + + else: + + sample = sorted(data) + bandwidth = h * support + + def pdf(x): + nonlocal n, sample + if len(data) != n: + sample = sorted(data) + n = len(data) + i = bisect_left(sample, x - bandwidth) + j = bisect_right(sample, x + bandwidth) + supported = sample[i : j] + return sum(K((x - x_i) / h) for x_i in supported) / (n * h) + + def cdf(x): + nonlocal n, sample + if len(data) != n: + sample = sorted(data) + n = len(data) + i = bisect_left(sample, x - bandwidth) + j = bisect_right(sample, x + bandwidth) + supported = sample[i : j] + return sum((W((x - x_i) / h) for x_i in supported), i) / n + + if cumulative: + cdf.__doc__ = f'CDF estimate with {h=!r} and {kernel=!r}' + return cdf + + else: + pdf.__doc__ = f'PDF estimate with {h=!r} and {kernel=!r}' + return pdf # Notes on methods for computing quantiles @@ -659,7 +1075,10 @@ def quantiles(data, *, n=4, method='exclusive'): data = sorted(data) ld = len(data) if ld < 2: - raise StatisticsError('must have at least two data points') + if ld == 1: + return data * (n - 1) + raise StatisticsError('must have at least one data point') + if method == 'inclusive': m = ld - 1 result = [] @@ -668,6 +1087,7 @@ def quantiles(data, *, n=4, method='exclusive'): interpolated = (data[j] * (n - delta) + data[j + 1] * delta) / n result.append(interpolated) return result + if method == 'exclusive': m = ld + 1 result = [] @@ -678,6 +1098,7 @@ def quantiles(data, *, n=4, method='exclusive'): interpolated = (data[j - 1] * (n - delta) + data[j] * delta) / n result.append(interpolated) return result + raise ValueError(f'Unknown method: {method!r}') @@ -685,41 +1106,6 @@ def quantiles(data, *, n=4, method='exclusive'): # See http://mathworld.wolfram.com/Variance.html # http://mathworld.wolfram.com/SampleVariance.html -# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -# -# Under no circumstances use the so-called "computational formula for -# variance", as that is only suitable for hand calculations with a small -# amount of low-precision data. It has terrible numeric properties. -# -# See a comparison of three computational methods here: -# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/ - -def _ss(data, c=None): - """Return sum of square deviations of sequence data. - - If ``c`` is None, the mean is calculated in one pass, and the deviations - from the mean are calculated in a second pass. Otherwise, deviations are - calculated from ``c`` as given. Use the second case with care, as it can - lead to garbage results. - """ - if c is not None: - T, total, count = _sum((x-c)**2 for x in data) - return (T, total) - T, total, count = _sum(data) - mean_n, mean_d = (total / count).as_integer_ratio() - partials = Counter() - for n, d in map(_exact_ratio, data): - diff_n = n * mean_d - d * mean_n - diff_d = d * mean_d - partials[diff_d * diff_d] += diff_n * diff_n - if None in partials: - # The sum will be a NAN or INF. We can ignore all the finite - # partials, and just look at this special one. - total = partials[None] - assert not _isfinite(total) - else: - total = sum(Fraction(n, d) for d, n in partials.items()) - return (T, total) def variance(data, xbar=None): @@ -760,12 +1146,9 @@ def variance(data, xbar=None): Fraction(67, 108) """ - if iter(data) is data: - data = list(data) - n = len(data) + T, ss, c, n = _ss(data, xbar) if n < 2: raise StatisticsError('variance requires at least two data points') - T, ss = _ss(data, xbar) return _convert(ss / (n - 1), T) @@ -804,12 +1187,9 @@ def pvariance(data, mu=None): Fraction(13, 72) """ - if iter(data) is data: - data = list(data) - n = len(data) + T, ss, c, n = _ss(data, mu) if n < 1: raise StatisticsError('pvariance requires at least one data point') - T, ss = _ss(data, mu) return _convert(ss / n, T) @@ -822,14 +1202,13 @@ def stdev(data, xbar=None): 1.0810874155219827 """ - # Fixme: Despite the exact sum of squared deviations, some inaccuracy - # remain because there are two rounding steps. The first occurs in - # the _convert() step for variance(), the second occurs in math.sqrt(). - var = variance(data, xbar) - try: - return var.sqrt() - except AttributeError: - return math.sqrt(var) + T, ss, c, n = _ss(data, xbar) + if n < 2: + raise StatisticsError('stdev requires at least two data points') + mss = ss / (n - 1) + if issubclass(T, Decimal): + return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) + return _float_sqrt_of_frac(mss.numerator, mss.denominator) def pstdev(data, mu=None): @@ -841,14 +1220,47 @@ def pstdev(data, mu=None): 0.986893273527251 """ - # Fixme: Despite the exact sum of squared deviations, some inaccuracy - # remain because there are two rounding steps. The first occurs in - # the _convert() step for pvariance(), the second occurs in math.sqrt(). - var = pvariance(data, mu) + T, ss, c, n = _ss(data, mu) + if n < 1: + raise StatisticsError('pstdev requires at least one data point') + mss = ss / n + if issubclass(T, Decimal): + return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) + return _float_sqrt_of_frac(mss.numerator, mss.denominator) + + +def _mean_stdev(data): + """In one pass, compute the mean and sample standard deviation as floats.""" + T, ss, xbar, n = _ss(data) + if n < 2: + raise StatisticsError('stdev requires at least two data points') + mss = ss / (n - 1) try: - return var.sqrt() + return float(xbar), _float_sqrt_of_frac(mss.numerator, mss.denominator) except AttributeError: - return math.sqrt(var) + # Handle Nans and Infs gracefully + return float(xbar), float(xbar) / float(ss) + +def _sqrtprod(x: float, y: float) -> float: + "Return sqrt(x * y) computed with improved accuracy and without overflow/underflow." + h = sqrt(x * y) + if not isfinite(h): + if isinf(h) and not isinf(x) and not isinf(y): + # Finite inputs overflowed, so scale down, and recompute. + scale = 2.0 ** -512 # sqrt(1 / sys.float_info.max) + return _sqrtprod(scale * x, scale * y) / scale + return h + if not h: + if x and y: + # Non-zero inputs underflowed, so scale up, and recompute. + # Scale: 1 / sqrt(sys.float_info.min * sys.float_info.epsilon) + scale = 2.0 ** 537 + return _sqrtprod(scale * x, scale * y) / scale + return h + # Improve accuracy with a differential correction. + # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0 + d = sumprod((x, h), (y, -h)) + return h + d / (2.0 * h) # === Statistics for relations between two inputs === @@ -882,18 +1294,16 @@ def covariance(x, y, /): raise StatisticsError('covariance requires at least two data points') xbar = fsum(x) / n ybar = fsum(y) / n - sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y)) + sxy = sumprod((xi - xbar for xi in x), (yi - ybar for yi in y)) return sxy / (n - 1) -def correlation(x, y, /): +def correlation(x, y, /, *, method='linear'): """Pearson's correlation coefficient Return the Pearson's correlation coefficient for two inputs. Pearson's - correlation coefficient *r* takes values between -1 and +1. It measures the - strength and direction of the linear relationship, where +1 means very - strong, positive linear relationship, -1 very strong, negative linear - relationship, and 0 no linear relationship. + correlation coefficient *r* takes values between -1 and +1. It measures + the strength and direction of a linear relationship. >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9] >>> y = [9, 8, 7, 6, 5, 4, 3, 2, 1] @@ -902,19 +1312,36 @@ def correlation(x, y, /): >>> correlation(x, y) -1.0 + If *method* is "ranked", computes Spearman's rank correlation coefficient + for two inputs. The data is replaced by ranks. Ties are averaged + so that equal values receive the same rank. The resulting coefficient + measures the strength of a monotonic relationship. + + Spearman's rank correlation coefficient is appropriate for ordinal + data or for continuous data that doesn't meet the linear proportion + requirement for Pearson's correlation coefficient. """ n = len(x) if len(y) != n: raise StatisticsError('correlation requires that both inputs have same number of data points') if n < 2: raise StatisticsError('correlation requires at least two data points') - xbar = fsum(x) / n - ybar = fsum(y) / n - sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y)) - sxx = fsum((xi - xbar) ** 2.0 for xi in x) - syy = fsum((yi - ybar) ** 2.0 for yi in y) + if method not in {'linear', 'ranked'}: + raise ValueError(f'Unknown method: {method!r}') + if method == 'ranked': + start = (n - 1) / -2 # Center rankings around zero + x = _rank(x, start=start) + y = _rank(y, start=start) + else: + xbar = fsum(x) / n + ybar = fsum(y) / n + x = [xi - xbar for xi in x] + y = [yi - ybar for yi in y] + sxy = sumprod(x, y) + sxx = sumprod(x, x) + syy = sumprod(y, y) try: - return sxy / sqrt(sxx * syy) + return sxy / _sqrtprod(sxx, syy) except ZeroDivisionError: raise StatisticsError('at least one of the inputs is constant') @@ -922,13 +1349,13 @@ def correlation(x, y, /): LinearRegression = namedtuple('LinearRegression', ('slope', 'intercept')) -def linear_regression(x, y, /): +def linear_regression(x, y, /, *, proportional=False): """Slope and intercept for simple linear regression. Return the slope and intercept of simple linear regression parameters estimated using ordinary least squares. Simple linear regression describes relationship between an independent variable - *x* and a dependent variable *y* in terms of linear function: + *x* and a dependent variable *y* in terms of a linear function: y = slope * x + intercept + noise @@ -944,7 +1371,20 @@ def linear_regression(x, y, /): >>> noise = NormalDist().samples(5, seed=42) >>> y = [3 * x[i] + 2 + noise[i] for i in range(5)] >>> linear_regression(x, y) #doctest: +ELLIPSIS - LinearRegression(slope=3.09078914170..., intercept=1.75684970486...) + LinearRegression(slope=3.17495..., intercept=1.00925...) + + If *proportional* is true, the independent variable *x* and the + dependent variable *y* are assumed to be directly proportional. + The data is fit to a line passing through the origin. + + Since the *intercept* will always be 0.0, the underlying linear + function simplifies to: + + y = slope * x + noise + + >>> y = [3 * x[i] + noise[i] for i in range(5)] + >>> linear_regression(x, y, proportional=True) #doctest: +ELLIPSIS + LinearRegression(slope=2.90475..., intercept=0.0) """ n = len(x) @@ -952,15 +1392,18 @@ def linear_regression(x, y, /): raise StatisticsError('linear regression requires that both inputs have same number of data points') if n < 2: raise StatisticsError('linear regression requires at least two data points') - xbar = fsum(x) / n - ybar = fsum(y) / n - sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y)) - sxx = fsum((xi - xbar) ** 2.0 for xi in x) + if not proportional: + xbar = fsum(x) / n + ybar = fsum(y) / n + x = [xi - xbar for xi in x] # List because used three times below + y = (yi - ybar for yi in y) # Generator because only used once below + sxy = sumprod(x, y) + 0.0 # Add zero to coerce result to a float + sxx = sumprod(x, x) try: slope = sxy / sxx # equivalent to: covariance(x, y) / variance(x) except ZeroDivisionError: raise StatisticsError('x is constant') - intercept = ybar - slope * xbar + intercept = 0.0 if proportional else ybar - slope * xbar return LinearRegression(slope=slope, intercept=intercept) @@ -1068,29 +1511,29 @@ def __init__(self, mu=0.0, sigma=1.0): @classmethod def from_samples(cls, data): "Make a normal distribution instance from sample data." - if not isinstance(data, (list, tuple)): - data = list(data) - xbar = fmean(data) - return cls(xbar, stdev(data, xbar)) + return cls(*_mean_stdev(data)) def samples(self, n, *, seed=None): "Generate *n* samples for a given mean and standard deviation." - gauss = random.gauss if seed is None else random.Random(seed).gauss - mu, sigma = self._mu, self._sigma - return [gauss(mu, sigma) for i in range(n)] + rnd = random.random if seed is None else random.Random(seed).random + inv_cdf = _normal_dist_inv_cdf + mu = self._mu + sigma = self._sigma + return [inv_cdf(rnd(), mu, sigma) for _ in repeat(None, n)] def pdf(self, x): "Probability density function. P(x <= X < x+dx) / dx" - variance = self._sigma ** 2.0 + variance = self._sigma * self._sigma if not variance: raise StatisticsError('pdf() not defined when sigma is zero') - return exp((x - self._mu)**2.0 / (-2.0*variance)) / sqrt(tau*variance) + diff = x - self._mu + return exp(diff * diff / (-2.0 * variance)) / sqrt(tau * variance) def cdf(self, x): "Cumulative distribution function. P(X <= x)" if not self._sigma: raise StatisticsError('cdf() not defined when sigma is zero') - return 0.5 * (1.0 + erf((x - self._mu) / (self._sigma * sqrt(2.0)))) + return 0.5 * (1.0 + erf((x - self._mu) / (self._sigma * _SQRT2))) def inv_cdf(self, p): """Inverse cumulative distribution function. x : P(X <= x) = p @@ -1104,8 +1547,6 @@ def inv_cdf(self, p): """ if p <= 0.0 or p >= 1.0: raise StatisticsError('p must be in the range 0.0 < p < 1.0') - if self._sigma <= 0.0: - raise StatisticsError('cdf() not defined when sigma at or below zero') return _normal_dist_inv_cdf(p, self._mu, self._sigma) def quantiles(self, n=4): @@ -1146,9 +1587,9 @@ def overlap(self, other): dv = Y_var - X_var dm = fabs(Y._mu - X._mu) if not dv: - return 1.0 - erf(dm / (2.0 * X._sigma * sqrt(2.0))) + return 1.0 - erf(dm / (2.0 * X._sigma * _SQRT2)) a = X._mu * Y_var - Y._mu * X_var - b = X._sigma * Y._sigma * sqrt(dm**2.0 + dv * log(Y_var / X_var)) + b = X._sigma * Y._sigma * sqrt(dm * dm + dv * log(Y_var / X_var)) x1 = (a + b) / dv x2 = (a - b) / dv return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2))) @@ -1191,7 +1632,7 @@ def stdev(self): @property def variance(self): "Square of the standard deviation." - return self._sigma ** 2.0 + return self._sigma * self._sigma def __add__(x1, x2): """Add a constant or another NormalDist instance. @@ -1265,3 +1706,102 @@ def __hash__(self): def __repr__(self): return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})' + + def __getstate__(self): + return self._mu, self._sigma + + def __setstate__(self, state): + self._mu, self._sigma = state + + +## kde_random() ############################################################## + +def _newton_raphson(f_inv_estimate, f, f_prime, tolerance=1e-12): + def f_inv(y): + "Return x such that f(x) ≈ y within the specified tolerance." + x = f_inv_estimate(y) + while abs(diff := f(x) - y) > tolerance: + x -= diff / f_prime(x) + return x + return f_inv + +def _quartic_invcdf_estimate(p): + sign, p = (1.0, p) if p <= 1/2 else (-1.0, 1.0 - p) + x = (2.0 * p) ** 0.4258865685331 - 1.0 + if p >= 0.004 < 0.499: + x += 0.026818732 * sin(7.101753784 * p + 2.73230839482953) + return x * sign + +_quartic_invcdf = _newton_raphson( + f_inv_estimate = _quartic_invcdf_estimate, + f = lambda t: 3/16 * t**5 - 5/8 * t**3 + 15/16 * t + 1/2, + f_prime = lambda t: 15/16 * (1.0 - t * t) ** 2) + +def _triweight_invcdf_estimate(p): + sign, p = (1.0, p) if p <= 1/2 else (-1.0, 1.0 - p) + x = (2.0 * p) ** 0.3400218741872791 - 1.0 + return x * sign + +_triweight_invcdf = _newton_raphson( + f_inv_estimate = _triweight_invcdf_estimate, + f = lambda t: 35/32 * (-1/7*t**7 + 3/5*t**5 - t**3 + t) + 1/2, + f_prime = lambda t: 35/32 * (1.0 - t * t) ** 3) + +_kernel_invcdfs = { + 'normal': NormalDist().inv_cdf, + 'logistic': lambda p: log(p / (1 - p)), + 'sigmoid': lambda p: log(tan(p * pi/2)), + 'rectangular': lambda p: 2*p - 1, + 'parabolic': lambda p: 2 * cos((acos(2*p-1) + pi) / 3), + 'quartic': _quartic_invcdf, + 'triweight': _triweight_invcdf, + 'triangular': lambda p: sqrt(2*p) - 1 if p < 1/2 else 1 - sqrt(2 - 2*p), + 'cosine': lambda p: 2 * asin(2*p - 1) / pi, +} +_kernel_invcdfs['gauss'] = _kernel_invcdfs['normal'] +_kernel_invcdfs['uniform'] = _kernel_invcdfs['rectangular'] +_kernel_invcdfs['epanechnikov'] = _kernel_invcdfs['parabolic'] +_kernel_invcdfs['biweight'] = _kernel_invcdfs['quartic'] + +def kde_random(data, h, kernel='normal', *, seed=None): + """Return a function that makes a random selection from the estimated + probability density function created by kde(data, h, kernel). + + Providing a *seed* allows reproducible selections within a single + thread. The seed may be an integer, float, str, or bytes. + + A StatisticsError will be raised if the *data* sequence is empty. + + Example: + + >>> data = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2] + >>> rand = kde_random(data, h=1.5, seed=8675309) + >>> new_selections = [rand() for i in range(10)] + >>> [round(x, 1) for x in new_selections] + [0.7, 6.2, 1.2, 6.9, 7.0, 1.8, 2.5, -0.5, -1.8, 5.6] + + """ + n = len(data) + if not n: + raise StatisticsError('Empty data sequence') + + if not isinstance(data[0], (int, float)): + raise TypeError('Data sequence must contain ints or floats') + + if h <= 0.0: + raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}') + + kernel_invcdf = _kernel_invcdfs.get(kernel) + if kernel_invcdf is None: + raise StatisticsError(f'Unknown kernel name: {kernel!r}') + + prng = _random.Random(seed) + random = prng.random + choice = prng.choice + + def rand(): + return choice(data) + h * kernel_invcdf(random()) + + rand.__doc__ = f'Random KDE selection with {h=!r} and {kernel=!r}' + + return rand diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index 8fcbcf3540..22bd17908a 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -1,4 +1,4 @@ -"""Test suite for statistics module, including helper NumericTestCase and +x = """Test suite for statistics module, including helper NumericTestCase and approx_equal function. """ @@ -9,13 +9,14 @@ import copy import decimal import doctest +import itertools import math import pickle import random import sys import unittest from test import support -from test.support import import_helper +from test.support import import_helper, requires_IEEE_754 from decimal import Decimal from fractions import Fraction @@ -27,6 +28,12 @@ # === Helper functions and class === +# Test copied from Lib/test/test_math.py +# detect evidence of double-rounding: fsum is not always correctly +# rounded on machines that suffer from double rounding. +x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer +HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4) + def sign(x): """Return -1.0 for negatives, including -0.0, otherwise +1.0.""" return math.copysign(1, x) @@ -691,14 +698,6 @@ def test_check_all(self): 'missing name "%s" in __all__' % name) -class DocTests(unittest.TestCase): - @unittest.skipIf(sys.flags.optimize >= 2, - "Docstrings are omitted with -OO and above") - def test_doc_tests(self): - failed, tried = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS) - self.assertGreater(tried, 0) - self.assertEqual(failed, 0) - class StatisticsErrorTest(unittest.TestCase): def test_has_exception(self): errmsg = ( @@ -1039,50 +1038,6 @@ def test_error_msg(self): self.assertEqual(errmsg, msg) -class FindLteqTest(unittest.TestCase): - # Test _find_lteq private function. - - def test_invalid_input_values(self): - for a, x in [ - ([], 1), - ([1, 2], 3), - ([1, 3], 2) - ]: - with self.subTest(a=a, x=x): - with self.assertRaises(ValueError): - statistics._find_lteq(a, x) - - def test_locate_successfully(self): - for a, x, expected_i in [ - ([1, 1, 1, 2, 3], 1, 0), - ([0, 1, 1, 1, 2, 3], 1, 1), - ([1, 2, 3, 3, 3], 3, 2) - ]: - with self.subTest(a=a, x=x): - self.assertEqual(expected_i, statistics._find_lteq(a, x)) - - -class FindRteqTest(unittest.TestCase): - # Test _find_rteq private function. - - def test_invalid_input_values(self): - for a, l, x in [ - ([1], 2, 1), - ([1, 3], 0, 2) - ]: - with self.assertRaises(ValueError): - statistics._find_rteq(a, l, x) - - def test_locate_successfully(self): - for a, l, x, expected_i in [ - ([1, 1, 1, 2, 3], 0, 1, 2), - ([0, 1, 1, 1, 2, 3], 0, 1, 3), - ([1, 2, 3, 3, 3], 0, 3, 4) - ]: - with self.subTest(a=a, l=l, x=x): - self.assertEqual(expected_i, statistics._find_rteq(a, l, x)) - - # === Tests for public functions === class UnivariateCommonMixin: @@ -1117,7 +1072,7 @@ def test_no_inplace_modifications(self): def test_order_doesnt_matter(self): # Test that the order of data points doesn't change the result. - # CAUTION: due to floating point rounding errors, the result actually + # CAUTION: due to floating-point rounding errors, the result actually # may depend on the order. Consider this test representing an ideal. # To avoid this test failing, only test with exact values such as ints # or Fractions. @@ -1210,6 +1165,9 @@ def __pow__(self, other): def __add__(self, other): return type(self)(super().__add__(other)) __radd__ = __add__ + def __mul__(self, other): + return type(self)(super().__mul__(other)) + __rmul__ = __mul__ return (float, Decimal, Fraction, MyFloat) def test_types_conserved(self): @@ -1782,6 +1740,12 @@ def test_repeated_single_value(self): data = [x]*count self.assertEqual(self.func(data), float(x)) + def test_single_value(self): + # Override method from AverageMixin. + # Average of a single value is the value as a float. + for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): + self.assertEqual(self.func([x]), float(x)) + def test_odd_fractions(self): # Test median_grouped works with an odd number of Fractions. F = Fraction @@ -1961,6 +1925,27 @@ def test_special_values(self): with self.assertRaises(ValueError): fmean([Inf, -Inf]) + def test_weights(self): + fmean = statistics.fmean + StatisticsError = statistics.StatisticsError + self.assertEqual( + fmean([10, 10, 10, 50], [0.25] * 4), + fmean([10, 10, 10, 50])) + self.assertEqual( + fmean([10, 10, 20], [0.25, 0.25, 0.50]), + fmean([10, 10, 20, 20])) + self.assertEqual( # inputs are iterators + fmean(iter([10, 10, 20]), iter([0.25, 0.25, 0.50])), + fmean([10, 10, 20, 20])) + with self.assertRaises(StatisticsError): + fmean([10, 20, 30], [1, 2]) # unequal lengths + with self.assertRaises(StatisticsError): + fmean(iter([10, 20, 30]), iter([1, 2])) # unequal lengths + with self.assertRaises(StatisticsError): + fmean([10, 20], [-1, 1]) # sum of weights is zero + with self.assertRaises(StatisticsError): + fmean(iter([10, 20]), iter([-1, 1])) # sum of weights is zero + # === Tests for variances and standard deviations === @@ -2137,6 +2122,104 @@ def test_center_not_at_mean(self): self.assertEqual(self.func(data), 2.5) self.assertEqual(self.func(data, mu=0.5), 6.5) +class TestSqrtHelpers(unittest.TestCase): + + def test_integer_sqrt_of_frac_rto(self): + for n, m in itertools.product(range(100), range(1, 1000)): + r = statistics._integer_sqrt_of_frac_rto(n, m) + self.assertIsInstance(r, int) + if r*r*m == n: + # Root is exact + continue + # Inexact, so the root should be odd + self.assertEqual(r&1, 1) + # Verify correct rounding + self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2) + + @requires_IEEE_754 + @support.requires_resource('cpu') + def test_float_sqrt_of_frac(self): + + def is_root_correctly_rounded(x: Fraction, root: float) -> bool: + if not x: + return root == 0.0 + + # Extract adjacent representable floats + r_up: float = math.nextafter(root, math.inf) + r_down: float = math.nextafter(root, -math.inf) + assert r_down < root < r_up + + # Convert to fractions for exact arithmetic + frac_root: Fraction = Fraction(root) + half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2 + half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2 + + # Check a closed interval. + # Does not test for a midpoint rounding rule. + return half_way_down ** 2 <= x <= half_way_up ** 2 + + randrange = random.randrange + + for i in range(60_000): + numerator: int = randrange(10 ** randrange(50)) + denonimator: int = randrange(10 ** randrange(50)) + 1 + with self.subTest(numerator=numerator, denonimator=denonimator): + x: Fraction = Fraction(numerator, denonimator) + root: float = statistics._float_sqrt_of_frac(numerator, denonimator) + self.assertTrue(is_root_correctly_rounded(x, root)) + + # Verify that corner cases and error handling match math.sqrt() + self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0) + with self.assertRaises(ValueError): + statistics._float_sqrt_of_frac(-1, 1) + with self.assertRaises(ValueError): + statistics._float_sqrt_of_frac(1, -1) + + # Error handling for zero denominator matches that for Fraction(1, 0) + with self.assertRaises(ZeroDivisionError): + statistics._float_sqrt_of_frac(1, 0) + + # The result is well defined if both inputs are negative + self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1)) + + def test_decimal_sqrt_of_frac(self): + root: Decimal + numerator: int + denominator: int + + for root, numerator, denominator in [ + (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj + (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up + (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down + ]: + with decimal.localcontext(decimal.DefaultContext): + self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root) + + # Confirm expected root with a quad precision decimal computation + with decimal.localcontext(decimal.DefaultContext) as ctx: + ctx.prec *= 4 + high_prec_ratio = Decimal(numerator) / Decimal(denominator) + ctx.rounding = decimal.ROUND_05UP + high_prec_root = high_prec_ratio.sqrt() + with decimal.localcontext(decimal.DefaultContext): + target_root = +high_prec_root + self.assertEqual(root, target_root) + + # Verify that corner cases and error handling match Decimal.sqrt() + self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0) + with self.assertRaises(decimal.InvalidOperation): + statistics._decimal_sqrt_of_frac(-1, 1) + with self.assertRaises(decimal.InvalidOperation): + statistics._decimal_sqrt_of_frac(1, -1) + + # Error handling for zero denominator matches that for Fraction(1, 0) + with self.assertRaises(ZeroDivisionError): + statistics._decimal_sqrt_of_frac(1, 0) + + # The result is well defined if both inputs are negative + self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1)) + + class TestStdev(VarianceStdevMixin, NumericTestCase): # Tests for sample standard deviation. def setUp(self): @@ -2151,7 +2234,7 @@ def test_compare_to_variance(self): # Test that stdev is, in fact, the square root of variance. data = [random.uniform(-2, 9) for _ in range(1000)] expected = math.sqrt(statistics.variance(data)) - self.assertEqual(self.func(data), expected) + self.assertAlmostEqual(self.func(data), expected) def test_center_not_at_mean(self): data = (1.0, 2.0) @@ -2219,10 +2302,12 @@ def test_error_cases(self): StatisticsError = statistics.StatisticsError with self.assertRaises(StatisticsError): geometric_mean([]) # empty input - with self.assertRaises(StatisticsError): - geometric_mean([3.5, 0.0, 5.25]) # zero input with self.assertRaises(StatisticsError): geometric_mean([3.5, -4.0, 5.25]) # negative input + with self.assertRaises(StatisticsError): + geometric_mean([0.0, -4.0, 5.25]) # negative input with zero + with self.assertRaises(StatisticsError): + geometric_mean([3.5, -math.inf, 5.25]) # negative infinity with self.assertRaises(StatisticsError): geometric_mean(iter([])) # empty iterator with self.assertRaises(TypeError): @@ -2245,6 +2330,203 @@ def test_special_values(self): with self.assertRaises(ValueError): geometric_mean([Inf, -Inf]) + # Cases with zero + self.assertEqual(geometric_mean([3, 0.0, 5]), 0.0) # Any zero gives a zero + self.assertEqual(geometric_mean([3, -0.0, 5]), 0.0) # Negative zero allowed + self.assertTrue(math.isnan(geometric_mean([0, NaN]))) # NaN beats zero + self.assertTrue(math.isnan(geometric_mean([0, Inf]))) # Because 0.0 * Inf -> NaN + + def test_mixed_int_and_float(self): + # Regression test for b.p.o. issue #28327 + geometric_mean = statistics.geometric_mean + expected_mean = 3.80675409583932 + values = [ + [2, 3, 5, 7], + [2, 3, 5, 7.0], + [2, 3, 5.0, 7.0], + [2, 3.0, 5.0, 7.0], + [2.0, 3.0, 5.0, 7.0], + ] + for v in values: + with self.subTest(v=v): + actual_mean = geometric_mean(v) + self.assertAlmostEqual(actual_mean, expected_mean, places=5) + + +class TestKDE(unittest.TestCase): + + @support.requires_resource('cpu') + def test_kde(self): + kde = statistics.kde + StatisticsError = statistics.StatisticsError + + kernels = ['normal', 'gauss', 'logistic', 'sigmoid', 'rectangular', + 'uniform', 'triangular', 'parabolic', 'epanechnikov', + 'quartic', 'biweight', 'triweight', 'cosine'] + + sample = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2] + + # The approximate integral of a PDF should be close to 1.0 + + def integrate(func, low, high, steps=10_000): + "Numeric approximation of a definite function integral." + dx = (high - low) / steps + midpoints = (low + (i + 1/2) * dx for i in range(steps)) + return sum(map(func, midpoints)) * dx + + for kernel in kernels: + with self.subTest(kernel=kernel): + f_hat = kde(sample, h=1.5, kernel=kernel) + area = integrate(f_hat, -20, 20) + self.assertAlmostEqual(area, 1.0, places=4) + + # Check CDF against an integral of the PDF + + data = [3, 5, 10, 12] + h = 2.3 + x = 10.5 + for kernel in kernels: + with self.subTest(kernel=kernel): + cdf = kde(data, h, kernel, cumulative=True) + f_hat = kde(data, h, kernel) + area = integrate(f_hat, -20, x, 100_000) + self.assertAlmostEqual(cdf(x), area, places=4) + + # Check error cases + + with self.assertRaises(StatisticsError): + kde([], h=1.0) # Empty dataset + with self.assertRaises(TypeError): + kde(['abc', 'def'], 1.5) # Non-numeric data + with self.assertRaises(TypeError): + kde(iter(sample), 1.5) # Data is not a sequence + with self.assertRaises(StatisticsError): + kde(sample, h=0.0) # Zero bandwidth + with self.assertRaises(StatisticsError): + kde(sample, h=-1.0) # Negative bandwidth + with self.assertRaises(TypeError): + kde(sample, h='str') # Wrong bandwidth type + with self.assertRaises(StatisticsError): + kde(sample, h=1.0, kernel='bogus') # Invalid kernel + with self.assertRaises(TypeError): + kde(sample, 1.0, 'gauss', True) # Positional cumulative argument + + # Test name and docstring of the generated function + + h = 1.5 + kernel = 'cosine' + f_hat = kde(sample, h, kernel) + self.assertEqual(f_hat.__name__, 'pdf') + self.assertIn(kernel, f_hat.__doc__) + self.assertIn(repr(h), f_hat.__doc__) + + # Test closed interval for the support boundaries. + # In particular, 'uniform' should non-zero at the boundaries. + + f_hat = kde([0], 1.0, 'uniform') + self.assertEqual(f_hat(-1.0), 1/2) + self.assertEqual(f_hat(1.0), 1/2) + + # Test online updates to data + + data = [1, 2] + f_hat = kde(data, 5.0, 'triangular') + self.assertEqual(f_hat(100), 0.0) + data.append(100) + self.assertGreater(f_hat(100), 0.0) + + def test_kde_kernel_invcdfs(self): + kernel_invcdfs = statistics._kernel_invcdfs + kde = statistics.kde + + # Verify that cdf / invcdf will round trip + xarr = [i/100 for i in range(-100, 101)] + for kernel, invcdf in kernel_invcdfs.items(): + with self.subTest(kernel=kernel): + cdf = kde([0.0], h=1.0, kernel=kernel, cumulative=True) + for x in xarr: + self.assertAlmostEqual(invcdf(cdf(x)), x, places=5) + + @support.requires_resource('cpu') + def test_kde_random(self): + kde_random = statistics.kde_random + StatisticsError = statistics.StatisticsError + kernels = ['normal', 'gauss', 'logistic', 'sigmoid', 'rectangular', + 'uniform', 'triangular', 'parabolic', 'epanechnikov', + 'quartic', 'biweight', 'triweight', 'cosine'] + sample = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2] + + # Smoke test + + for kernel in kernels: + with self.subTest(kernel=kernel): + rand = kde_random(sample, h=1.5, kernel=kernel) + selections = [rand() for i in range(10)] + + # Check error cases + + with self.assertRaises(StatisticsError): + kde_random([], h=1.0) # Empty dataset + with self.assertRaises(TypeError): + kde_random(['abc', 'def'], 1.5) # Non-numeric data + with self.assertRaises(TypeError): + kde_random(iter(sample), 1.5) # Data is not a sequence + with self.assertRaises(StatisticsError): + kde_random(sample, h=-1.0) # Zero bandwidth + with self.assertRaises(StatisticsError): + kde_random(sample, h=0.0) # Negative bandwidth + with self.assertRaises(TypeError): + kde_random(sample, h='str') # Wrong bandwidth type + with self.assertRaises(StatisticsError): + kde_random(sample, h=1.0, kernel='bogus') # Invalid kernel + + # Test name and docstring of the generated function + + h = 1.5 + kernel = 'cosine' + rand = kde_random(sample, h, kernel) + self.assertEqual(rand.__name__, 'rand') + self.assertIn(kernel, rand.__doc__) + self.assertIn(repr(h), rand.__doc__) + + # Approximate distribution test: Compare a random sample to the expected distribution + + data = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2, 7.8, 14.3, 15.1, 15.3, 15.8, 17.0] + xarr = [x / 10 for x in range(-100, 250)] + # TODO: RUSTPYTHON - n originally 1_000_000 in CPython implementation but sorting is slow + n = 30_000 + h = 1.75 + dx = 0.1 + + def p_observed(x): + # P(x <= X < x+dx) + i = bisect.bisect_left(big_sample, x) + j = bisect.bisect_left(big_sample, x + dx) + return (j - i) / len(big_sample) + + def p_expected(x): + # P(x <= X < x+dx) + return F_hat(x + dx) - F_hat(x) + + for kernel in kernels: + with self.subTest(kernel=kernel): + + rand = kde_random(data, h, kernel, seed=8675309**2) + big_sample = sorted([rand() for i in range(n)]) + F_hat = statistics.kde(data, h, kernel, cumulative=True) + + for x in xarr: + # TODO : RUSTPYTHON - abs_tol=0.0005 in CPython implementation but smaller `n` increases variance + self.assertTrue(math.isclose(p_observed(x), p_expected(x), abs_tol=0.005)) + + # Test online updates to data + + data = [1, 2] + rand = kde_random(data, 5, 'triangular') + self.assertLess(max([rand() for i in range(5000)]), 10) + data.append(100) + self.assertGreater(max(rand() for i in range(5000)), 10) + class TestQuantiles(unittest.TestCase): @@ -2355,6 +2637,11 @@ def f(x): data = random.choices(range(100), k=k) q1, q2, q3 = quantiles(data, method='inclusive') self.assertEqual(q2, statistics.median(data)) + # Base case with a single data point: When estimating quantiles from + # a sample, we want to be able to add one sample point at a time, + # getting increasingly better estimates. + self.assertEqual(quantiles([10], n=4), [10.0, 10.0, 10.0]) + self.assertEqual(quantiles([10], n=4, method='exclusive'), [10.0, 10.0, 10.0]) def test_equal_inputs(self): quantiles = statistics.quantiles @@ -2405,7 +2692,7 @@ def test_error_cases(self): with self.assertRaises(ValueError): quantiles([10, 20, 30], method='X') # method is unknown with self.assertRaises(StatisticsError): - quantiles([10], n=4) # not enough data points + quantiles([], n=4) # not enough data points with self.assertRaises(TypeError): quantiles([10, None, 30], n=4) # data is non-numeric @@ -2464,6 +2751,95 @@ def test_different_scales(self): self.assertAlmostEqual(statistics.correlation(x, y), 1) self.assertAlmostEqual(statistics.covariance(x, y), 0.1) + def test_sqrtprod_helper_function_fundamentals(self): + # Verify that results are close to sqrt(x * y) + for i in range(100): + x = random.expovariate() + y = random.expovariate() + expected = math.sqrt(x * y) + actual = statistics._sqrtprod(x, y) + with self.subTest(x=x, y=y, expected=expected, actual=actual): + self.assertAlmostEqual(expected, actual) + + x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661 + self.assertEqual(statistics._sqrtprod(x, y), target) + self.assertNotEqual(math.sqrt(x * y), target) + + # Test that range extremes avoid underflow and overflow + smallest = sys.float_info.min * sys.float_info.epsilon + self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest) + biggest = sys.float_info.max + self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest) + + # Check special values and the sign of the result + special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0, + math.nan, -math.nan, math.inf, -math.inf] + for x, y in itertools.product(special_values, repeat=2): + try: + expected = math.sqrt(x * y) + except ValueError: + expected = 'ValueError' + try: + actual = statistics._sqrtprod(x, y) + except ValueError: + actual = 'ValueError' + with self.subTest(x=x, y=y, expected=expected, actual=actual): + if isinstance(expected, str) and expected == 'ValueError': + self.assertEqual(actual, 'ValueError') + continue + self.assertIsInstance(actual, float) + if math.isnan(expected): + self.assertTrue(math.isnan(actual)) + continue + self.assertEqual(actual, expected) + self.assertEqual(sign(actual), sign(expected)) + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Allow for a weaker sumprod() implmentation + def test_sqrtprod_helper_function_improved_accuracy(self): + # Test a known example where accuracy is improved + x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661 + self.assertEqual(statistics._sqrtprod(x, y), target) + self.assertNotEqual(math.sqrt(x * y), target) + + def reference_value(x: float, y: float) -> float: + x = decimal.Decimal(x) + y = decimal.Decimal(y) + with decimal.localcontext() as ctx: + ctx.prec = 200 + return float((x * y).sqrt()) + + # Verify that the new function with improved accuracy + # agrees with a reference value more often than old version. + new_agreements = 0 + old_agreements = 0 + for i in range(10_000): + x = random.expovariate() + y = random.expovariate() + new = statistics._sqrtprod(x, y) + old = math.sqrt(x * y) + ref = reference_value(x, y) + new_agreements += (new == ref) + old_agreements += (old == ref) + self.assertGreater(new_agreements, old_agreements) + + def test_correlation_spearman(self): + # https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php + # Compare with: + # >>> import scipy.stats.mstats + # >>> scipy.stats.mstats.spearmanr(reading, mathematics) + # SpearmanrResult(correlation=0.6686960980480712, pvalue=0.03450954165178532) + # And Wolfram Alpha gives: 0.668696 + # https://www.wolframalpha.com/input?i=SpearmanRho%5B%7B56%2C+75%2C+45%2C+71%2C+61%2C+64%2C+58%2C+80%2C+76%2C+61%7D%2C+%7B66%2C+70%2C+40%2C+60%2C+65%2C+56%2C+59%2C+77%2C+67%2C+63%7D%5D + reading = [56, 75, 45, 71, 61, 64, 58, 80, 76, 61] + mathematics = [66, 70, 40, 60, 65, 56, 59, 77, 67, 63] + self.assertAlmostEqual(statistics.correlation(reading, mathematics, method='ranked'), + 0.6686960980480712) + + with self.assertRaises(ValueError): + statistics.correlation(reading, mathematics, method='bad_method') class TestLinearRegression(unittest.TestCase): @@ -2487,6 +2863,22 @@ def test_results(self): self.assertAlmostEqual(intercept, true_intercept) self.assertAlmostEqual(slope, true_slope) + def test_proportional(self): + x = [10, 20, 30, 40] + y = [180, 398, 610, 799] + slope, intercept = statistics.linear_regression(x, y, proportional=True) + self.assertAlmostEqual(slope, 20 + 1/150) + self.assertEqual(intercept, 0.0) + + def test_float_output(self): + x = [Fraction(2, 3), Fraction(3, 4)] + y = [Fraction(4, 5), Fraction(5, 6)] + slope, intercept = statistics.linear_regression(x, y) + self.assertTrue(isinstance(slope, float)) + self.assertTrue(isinstance(intercept, float)) + slope, intercept = statistics.linear_regression(x, y, proportional=True) + self.assertTrue(isinstance(slope, float)) + self.assertTrue(isinstance(intercept, float)) class TestNormalDist: @@ -2640,6 +3032,7 @@ def test_cdf(self): self.assertTrue(math.isnan(X.cdf(float('NaN')))) @support.skip_if_pgo_task + @support.requires_resource('cpu') def test_inv_cdf(self): NormalDist = self.module.NormalDist @@ -2697,9 +3090,10 @@ def test_inv_cdf(self): iq.inv_cdf(1.0) # p is one with self.assertRaises(self.module.StatisticsError): iq.inv_cdf(1.1) # p over one - with self.assertRaises(self.module.StatisticsError): - iq = NormalDist(100, 0) # sigma is zero - iq.inv_cdf(0.5) + + # Supported case: + iq = NormalDist(100, 0) # sigma is zero + self.assertEqual(iq.inv_cdf(0.5), 100) # Special values self.assertTrue(math.isnan(Z.inv_cdf(float('NaN')))) @@ -2882,14 +3276,19 @@ def __init__(self, mu, sigma): nd = NormalDist(100, 15) self.assertNotEqual(nd, lnd) - def test_pickle_and_copy(self): + def test_copy(self): nd = self.module.NormalDist(37.5, 5.625) nd1 = copy.copy(nd) self.assertEqual(nd, nd1) nd2 = copy.deepcopy(nd) self.assertEqual(nd, nd2) - nd3 = pickle.loads(pickle.dumps(nd)) - self.assertEqual(nd, nd3) + + def test_pickle(self): + nd = self.module.NormalDist(37.5, 5.625) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = pickle.loads(pickle.dumps(nd, protocol=proto)) + self.assertEqual(nd, pickled) def test_hashability(self): ND = self.module.NormalDist @@ -2911,7 +3310,7 @@ def setUp(self): def tearDown(self): sys.modules['statistics'] = statistics - + @unittest.skipUnless(c_statistics, 'requires _statistics') class TestNormalDistC(unittest.TestCase, TestNormalDist): @@ -2928,6 +3327,8 @@ def tearDown(self): def load_tests(loader, tests, ignore): """Used for doctest/unittest integration.""" tests.addTests(doctest.DocTestSuite()) + if sys.float_repr_style == 'short': + tests.addTests(doctest.DocTestSuite(statistics)) return tests diff --git a/stdlib/src/statistics.rs b/stdlib/src/statistics.rs index 141493c125..9f5b294c00 100644 --- a/stdlib/src/statistics.rs +++ b/stdlib/src/statistics.rs @@ -7,7 +7,7 @@ mod _statistics { // See https://github.com/python/cpython/blob/6846d6712a0894f8e1a91716c11dd79f42864216/Modules/_statisticsmodule.c#L28-L120 #[allow(clippy::excessive_precision)] fn normal_dist_inv_cdf(p: f64, mu: f64, sigma: f64) -> Option { - if p <= 0.0 || p >= 1.0 || sigma <= 0.0 { + if p <= 0.0 || p >= 1.0 { return None; }