diff --git a/Cargo.lock b/Cargo.lock index 023c6240ce..02cd179ca0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -697,6 +697,12 @@ dependencies = [ "autocfg 1.0.0", ] +[[package]] +name = "instant" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7152d2aed88aa566e7a342250f21ba2222c1ae230ad577499dbfa3c18475b80" + [[package]] name = "is-macro" version = "0.1.8" @@ -826,6 +832,14 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "lock_api" +version = "0.3.4" +source = "git+https://github.com/Amanieu/parking_lot#ecaa94438e570c84f1a4c7db830916890f2ae44c" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.8" @@ -1076,6 +1090,30 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a86ed3f5f244b372d6b1a00b72ef7f8876d0bc6a78a4c9985c53614041512063" +[[package]] +name = "parking_lot" +version = "0.10.2" +source = "git+https://github.com/Amanieu/parking_lot#ecaa94438e570c84f1a4c7db830916890f2ae44c" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.7.2" +source = "git+https://github.com/Amanieu/parking_lot#ecaa94438e570c84f1a4c7db830916890f2ae44c" +dependencies = [ + "cfg-if", + "cloudabi", + "instant", + "libc", + "redox_syscall", + "smallvec", + "winapi", +] + [[package]] name = "paste" version = "0.1.10" @@ -1575,6 +1613,7 @@ dependencies = [ "flamer", "flate2", "foreign-types-shared", + "generational-arena", "gethostname", "getrandom", "hex", @@ -1602,6 +1641,7 @@ dependencies = [ "openssl", "openssl-probe", "openssl-sys", + "parking_lot", "paste", "pwd", "rand 0.7.3", @@ -1697,6 +1737,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "semver" version = "0.9.0" diff --git a/Lib/_rp_thread.py b/Lib/_rp_thread.py deleted file mode 100644 index b843c3ec97..0000000000 --- a/Lib/_rp_thread.py +++ /dev/null @@ -1,7 +0,0 @@ -import _thread -import _dummy_thread - -for k in _dummy_thread.__all__ + ['_set_sentinel', 'stack_size']: - if k not in _thread.__dict__: - # print('Populating _thread.%s' % k) - setattr(_thread, k, getattr(_dummy_thread, k)) diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py index 304c66f59b..7a84823622 100644 --- a/Lib/_weakrefset.py +++ b/Lib/_weakrefset.py @@ -194,3 +194,6 @@ def union(self, other): def isdisjoint(self, other): return len(self.intersection(other)) == 0 + + def __repr__(self): + return repr(self.data) diff --git a/Lib/importlib/_bootstrap.py b/Lib/importlib/_bootstrap.py index 32deef10af..cb35633493 100644 --- a/Lib/importlib/_bootstrap.py +++ b/Lib/importlib/_bootstrap.py @@ -1149,12 +1149,21 @@ def _setup(sys_module, _imp_module): # Directly load built-in modules needed during bootstrap. self_module = sys.modules[__name__] - for builtin_name in ('_thread', '_warnings', '_weakref'): + for builtin_name in ('_warnings', '_weakref'): if builtin_name not in sys.modules: builtin_module = _builtin_from_name(builtin_name) else: builtin_module = sys.modules[builtin_name] setattr(self_module, builtin_name, builtin_module) + # _thread was part of the above loop, but other parts of the code allow for it + # to be None, so we handle it separately here + builtin_name = '_thread' + if builtin_name in sys.modules: + builtin_module = sys.modules[builtin_name] + else: + builtin_spec = BuiltinImporter.find_spec(builtin_name) + builtin_module = builtin_spec and _load_unlocked(builtin_spec) + setattr(self_module, builtin_name, builtin_module) def _install(sys_module, _imp_module): diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py index 89b5b886a5..436087dd72 100644 --- a/Lib/logging/__init__.py +++ b/Lib/logging/__init__.py @@ -37,8 +37,7 @@ 'warn', 'warning', 'getLogRecordFactory', 'setLogRecordFactory', 'lastResort', 'raiseExceptions'] -# TODO: import threading -import _thread +import threading __author__ = "Vinay Sajip " __status__ = "production" @@ -208,7 +207,7 @@ def _checkLevel(level): #the lock would already have been acquired - so we need an RLock. #The same argument applies to Loggers and Manager.loggerDict. # -_lock = _thread.RLock() +_lock = threading.RLock() def _acquireLock(): """ @@ -844,7 +843,7 @@ def createLock(self): """ Acquire a thread lock for serializing access to the underlying I/O. """ - self.lock = _thread.RLock() + self.lock = threading.RLock() _register_at_fork_acquire_release(self) def acquire(self): diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py new file mode 100644 index 0000000000..7b1ad8eb6d --- /dev/null +++ b/Lib/test/lock_tests.py @@ -0,0 +1,949 @@ +""" +Various tests for synchronization primitives. +""" + +import sys +import time +from _thread import start_new_thread, TIMEOUT_MAX +import threading +import unittest +import weakref + +from test import support + + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, f, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.n = n + self.started = [] + self.finished = [] + self._can_exit = not wait_before_exit + self.wait_thread = support.wait_threads_exit() + self.wait_thread.__enter__() + + def task(): + tid = threading.get_ident() + self.started.append(tid) + try: + f() + finally: + self.finished.append(tid) + while not self._can_exit: + _wait() + + try: + for i in range(n): + start_new_thread(task, ()) + except: + self._can_exit = True + raise + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + # Wait for threads exit + self.wait_thread.__exit__(None, None, None) + + def do_finish(self): + self._can_exit = True + + +class BaseTestCase(unittest.TestCase): + def setUp(self): + self._threads = support.threading_setup() + + def tearDown(self): + support.threading_cleanup(*self._threads) + support.reap_children() + + def assertTimeout(self, actual, expected): + # The waiting and/or time.monotonic() can be imprecise, which + # is why comparing to the expected value would sometimes fail + # (especially under Windows). + self.assertGreaterEqual(actual, expected * 0.6) + # Test nothing insane happened + self.assertLess(actual, expected * 10.0) + + +class BaseLockTests(BaseTestCase): + """ + Tests for both recursive and non-recursive locks. + """ + + def test_constructor(self): + lock = self.locktype() + del lock + + def test_repr(self): + lock = self.locktype() + self.assertRegex(repr(lock), "") + del lock + + def test_locked_repr(self): + lock = self.locktype() + lock.acquire() + self.assertRegex(repr(lock), "") + del lock + + def test_acquire_destroy(self): + lock = self.locktype() + lock.acquire() + del lock + + def test_acquire_release(self): + lock = self.locktype() + lock.acquire() + lock.release() + del lock + + def test_try_acquire(self): + lock = self.locktype() + self.assertTrue(lock.acquire(False)) + lock.release() + + def test_try_acquire_contended(self): + lock = self.locktype() + lock.acquire() + result = [] + def f(): + result.append(lock.acquire(False)) + Bunch(f, 1).wait_for_finished() + self.assertFalse(result[0]) + lock.release() + + def test_acquire_contended(self): + lock = self.locktype() + lock.acquire() + N = 5 + def f(): + lock.acquire() + lock.release() + + b = Bunch(f, N) + b.wait_for_started() + _wait() + self.assertEqual(len(b.finished), 0) + lock.release() + b.wait_for_finished() + self.assertEqual(len(b.finished), N) + + def test_with(self): + lock = self.locktype() + def f(): + lock.acquire() + lock.release() + def _with(err=None): + with lock: + if err is not None: + raise err + _with() + # Check the lock is unacquired + Bunch(f, 1).wait_for_finished() + self.assertRaises(TypeError, _with, TypeError) + # Check the lock is unacquired + Bunch(f, 1).wait_for_finished() + + def test_thread_leak(self): + # The lock shouldn't leak a Thread instance when used from a foreign + # (non-threading) thread. + lock = self.locktype() + def f(): + lock.acquire() + lock.release() + n = len(threading.enumerate()) + # We run many threads in the hope that existing threads ids won't + # be recycled. + Bunch(f, 15).wait_for_finished() + if len(threading.enumerate()) != n: + # There is a small window during which a Thread instance's + # target function has finished running, but the Thread is still + # alive and registered. Avoid spurious failures by waiting a + # bit more (seen on a buildbot). + time.sleep(0.4) + self.assertEqual(n, len(threading.enumerate())) + + def test_timeout(self): + lock = self.locktype() + # Can't set timeout if not blocking + self.assertRaises(ValueError, lock.acquire, 0, 1) + # Invalid timeout values + self.assertRaises(ValueError, lock.acquire, timeout=-100) + self.assertRaises(OverflowError, lock.acquire, timeout=1e100) + self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1) + # TIMEOUT_MAX is ok + lock.acquire(timeout=TIMEOUT_MAX) + lock.release() + t1 = time.monotonic() + self.assertTrue(lock.acquire(timeout=5)) + t2 = time.monotonic() + # Just a sanity test that it didn't actually wait for the timeout. + self.assertLess(t2 - t1, 5) + results = [] + def f(): + t1 = time.monotonic() + results.append(lock.acquire(timeout=0.5)) + t2 = time.monotonic() + results.append(t2 - t1) + Bunch(f, 1).wait_for_finished() + self.assertFalse(results[0]) + self.assertTimeout(results[1], 0.5) + + def test_weakref_exists(self): + lock = self.locktype() + ref = weakref.ref(lock) + self.assertIsNotNone(ref()) + + def test_weakref_deleted(self): + lock = self.locktype() + ref = weakref.ref(lock) + del lock + self.assertIsNone(ref()) + + +class LockTests(BaseLockTests): + """ + Tests for non-recursive, weak locks + (which can be acquired and released from different threads). + """ + def test_reacquire(self): + # Lock needs to be released before re-acquiring. + lock = self.locktype() + phase = [] + + def f(): + lock.acquire() + phase.append(None) + lock.acquire() + phase.append(None) + + with support.wait_threads_exit(): + start_new_thread(f, ()) + while len(phase) == 0: + _wait() + _wait() + self.assertEqual(len(phase), 1) + lock.release() + while len(phase) == 1: + _wait() + self.assertEqual(len(phase), 2) + + def test_different_thread(self): + # Lock can be released from a different thread. + lock = self.locktype() + lock.acquire() + def f(): + lock.release() + b = Bunch(f, 1) + b.wait_for_finished() + lock.acquire() + lock.release() + + def test_state_after_timeout(self): + # Issue #11618: check that lock is in a proper state after a + # (non-zero) timeout. + lock = self.locktype() + lock.acquire() + self.assertFalse(lock.acquire(timeout=0.01)) + lock.release() + self.assertFalse(lock.locked()) + self.assertTrue(lock.acquire(blocking=False)) + + +class RLockTests(BaseLockTests): + """ + Tests for recursive locks. + """ + def test_reacquire(self): + lock = self.locktype() + lock.acquire() + lock.acquire() + lock.release() + lock.acquire() + lock.release() + lock.release() + + def test_release_unacquired(self): + # Cannot release an unacquired lock + lock = self.locktype() + self.assertRaises(RuntimeError, lock.release) + lock.acquire() + lock.acquire() + lock.release() + lock.acquire() + lock.release() + lock.release() + self.assertRaises(RuntimeError, lock.release) + + def test_release_save_unacquired(self): + # Cannot _release_save an unacquired lock + lock = self.locktype() + self.assertRaises(RuntimeError, lock._release_save) + lock.acquire() + lock.acquire() + lock.release() + lock.acquire() + lock.release() + lock.release() + self.assertRaises(RuntimeError, lock._release_save) + + def test_different_thread(self): + # Cannot release from a different thread + lock = self.locktype() + def f(): + lock.acquire() + b = Bunch(f, 1, True) + try: + self.assertRaises(RuntimeError, lock.release) + finally: + b.do_finish() + b.wait_for_finished() + + def test__is_owned(self): + lock = self.locktype() + self.assertFalse(lock._is_owned()) + lock.acquire() + self.assertTrue(lock._is_owned()) + lock.acquire() + self.assertTrue(lock._is_owned()) + result = [] + def f(): + result.append(lock._is_owned()) + Bunch(f, 1).wait_for_finished() + self.assertFalse(result[0]) + lock.release() + self.assertTrue(lock._is_owned()) + lock.release() + self.assertFalse(lock._is_owned()) + + +class EventTests(BaseTestCase): + """ + Tests for Event objects. + """ + + def test_is_set(self): + evt = self.eventtype() + self.assertFalse(evt.is_set()) + evt.set() + self.assertTrue(evt.is_set()) + evt.set() + self.assertTrue(evt.is_set()) + evt.clear() + self.assertFalse(evt.is_set()) + evt.clear() + self.assertFalse(evt.is_set()) + + def _check_notify(self, evt): + # All threads get notified + N = 5 + results1 = [] + results2 = [] + def f(): + results1.append(evt.wait()) + results2.append(evt.wait()) + b = Bunch(f, N) + b.wait_for_started() + _wait() + self.assertEqual(len(results1), 0) + evt.set() + b.wait_for_finished() + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + + def test_notify(self): + evt = self.eventtype() + self._check_notify(evt) + # Another time, after an explicit clear() + evt.set() + evt.clear() + self._check_notify(evt) + + def test_timeout(self): + evt = self.eventtype() + results1 = [] + results2 = [] + N = 5 + def f(): + results1.append(evt.wait(0.0)) + t1 = time.monotonic() + r = evt.wait(0.5) + t2 = time.monotonic() + results2.append((r, t2 - t1)) + Bunch(f, N).wait_for_finished() + self.assertEqual(results1, [False] * N) + for r, dt in results2: + self.assertFalse(r) + self.assertTimeout(dt, 0.5) + # The event is set + results1 = [] + results2 = [] + evt.set() + Bunch(f, N).wait_for_finished() + self.assertEqual(results1, [True] * N) + for r, dt in results2: + self.assertTrue(r) + + def test_set_and_clear(self): + # Issue #13502: check that wait() returns true even when the event is + # cleared before the waiting thread is woken up. + evt = self.eventtype() + results = [] + timeout = 0.250 + N = 5 + def f(): + results.append(evt.wait(timeout * 4)) + b = Bunch(f, N) + b.wait_for_started() + time.sleep(timeout) + evt.set() + evt.clear() + b.wait_for_finished() + self.assertEqual(results, [True] * N) + + def test_reset_internal_locks(self): + # ensure that condition is still using a Lock after reset + evt = self.eventtype() + with evt._cond: + self.assertFalse(evt._cond.acquire(False)) + evt._reset_internal_locks() + with evt._cond: + self.assertFalse(evt._cond.acquire(False)) + + +class ConditionTests(BaseTestCase): + """ + Tests for condition variables. + """ + + def test_acquire(self): + cond = self.condtype() + # Be default we have an RLock: the condition can be acquired multiple + # times. + cond.acquire() + cond.acquire() + cond.release() + cond.release() + lock = threading.Lock() + cond = self.condtype(lock) + cond.acquire() + self.assertFalse(lock.acquire(False)) + cond.release() + self.assertTrue(lock.acquire(False)) + self.assertFalse(cond.acquire(False)) + lock.release() + with cond: + self.assertFalse(lock.acquire(False)) + + def test_unacquired_wait(self): + cond = self.condtype() + self.assertRaises(RuntimeError, cond.wait) + + def test_unacquired_notify(self): + cond = self.condtype() + self.assertRaises(RuntimeError, cond.notify) + + def _check_notify(self, cond): + # Note that this test is sensitive to timing. If the worker threads + # don't execute in a timely fashion, the main thread may think they + # are further along then they are. The main thread therefore issues + # _wait() statements to try to make sure that it doesn't race ahead + # of the workers. + # Secondly, this test assumes that condition variables are not subject + # to spurious wakeups. The absence of spurious wakeups is an implementation + # detail of Condition Variables in current CPython, but in general, not + # a guaranteed property of condition variables as a programming + # construct. In particular, it is possible that this can no longer + # be conveniently guaranteed should their implementation ever change. + N = 5 + ready = [] + results1 = [] + results2 = [] + phase_num = 0 + def f(): + cond.acquire() + ready.append(phase_num) + result = cond.wait() + cond.release() + results1.append((result, phase_num)) + cond.acquire() + ready.append(phase_num) + result = cond.wait() + cond.release() + results2.append((result, phase_num)) + b = Bunch(f, N) + b.wait_for_started() + # first wait, to ensure all workers settle into cond.wait() before + # we continue. See issues #8799 and #30727. + while len(ready) < 5: + _wait() + ready.clear() + self.assertEqual(results1, []) + # Notify 3 threads at first + cond.acquire() + cond.notify(3) + _wait() + phase_num = 1 + cond.release() + while len(results1) < 3: + _wait() + self.assertEqual(results1, [(True, 1)] * 3) + self.assertEqual(results2, []) + # make sure all awaken workers settle into cond.wait() + while len(ready) < 3: + _wait() + # Notify 5 threads: they might be in their first or second wait + cond.acquire() + cond.notify(5) + _wait() + phase_num = 2 + cond.release() + while len(results1) + len(results2) < 8: + _wait() + self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2) + self.assertEqual(results2, [(True, 2)] * 3) + # make sure all workers settle into cond.wait() + while len(ready) < 5: + _wait() + # Notify all threads: they are all in their second wait + cond.acquire() + cond.notify_all() + _wait() + phase_num = 3 + cond.release() + while len(results2) < 5: + _wait() + self.assertEqual(results1, [(True, 1)] * 3 + [(True,2)] * 2) + self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2) + b.wait_for_finished() + + def test_notify(self): + cond = self.condtype() + self._check_notify(cond) + # A second time, to check internal state is still ok. + self._check_notify(cond) + + def test_timeout(self): + cond = self.condtype() + results = [] + N = 5 + def f(): + cond.acquire() + t1 = time.monotonic() + result = cond.wait(0.5) + t2 = time.monotonic() + cond.release() + results.append((t2 - t1, result)) + Bunch(f, N).wait_for_finished() + self.assertEqual(len(results), N) + for dt, result in results: + self.assertTimeout(dt, 0.5) + # Note that conceptually (that"s the condition variable protocol) + # a wait() may succeed even if no one notifies us and before any + # timeout occurs. Spurious wakeups can occur. + # This makes it hard to verify the result value. + # In practice, this implementation has no spurious wakeups. + self.assertFalse(result) + + def test_waitfor(self): + cond = self.condtype() + state = 0 + def f(): + with cond: + result = cond.wait_for(lambda : state==4) + self.assertTrue(result) + self.assertEqual(state, 4) + b = Bunch(f, 1) + b.wait_for_started() + for i in range(4): + time.sleep(0.01) + with cond: + state += 1 + cond.notify() + b.wait_for_finished() + + def test_waitfor_timeout(self): + cond = self.condtype() + state = 0 + success = [] + def f(): + with cond: + dt = time.monotonic() + result = cond.wait_for(lambda : state==4, timeout=0.1) + dt = time.monotonic() - dt + self.assertFalse(result) + self.assertTimeout(dt, 0.1) + success.append(None) + b = Bunch(f, 1) + b.wait_for_started() + # Only increment 3 times, so state == 4 is never reached. + for i in range(3): + time.sleep(0.01) + with cond: + state += 1 + cond.notify() + b.wait_for_finished() + self.assertEqual(len(success), 1) + + +class BaseSemaphoreTests(BaseTestCase): + """ + Common tests for {bounded, unbounded} semaphore objects. + """ + + def test_constructor(self): + self.assertRaises(ValueError, self.semtype, value = -1) + self.assertRaises(ValueError, self.semtype, value = -sys.maxsize) + + def test_acquire(self): + sem = self.semtype(1) + sem.acquire() + sem.release() + sem = self.semtype(2) + sem.acquire() + sem.acquire() + sem.release() + sem.release() + + def test_acquire_destroy(self): + sem = self.semtype() + sem.acquire() + del sem + + def test_acquire_contended(self): + sem = self.semtype(7) + sem.acquire() + N = 10 + sem_results = [] + results1 = [] + results2 = [] + phase_num = 0 + def f(): + sem_results.append(sem.acquire()) + results1.append(phase_num) + sem_results.append(sem.acquire()) + results2.append(phase_num) + b = Bunch(f, 10) + b.wait_for_started() + while len(results1) + len(results2) < 6: + _wait() + self.assertEqual(results1 + results2, [0] * 6) + phase_num = 1 + for i in range(7): + sem.release() + while len(results1) + len(results2) < 13: + _wait() + self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7) + phase_num = 2 + for i in range(6): + sem.release() + while len(results1) + len(results2) < 19: + _wait() + self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6) + # The semaphore is still locked + self.assertFalse(sem.acquire(False)) + # Final release, to let the last thread finish + sem.release() + b.wait_for_finished() + self.assertEqual(sem_results, [True] * (6 + 7 + 6 + 1)) + + def test_try_acquire(self): + sem = self.semtype(2) + self.assertTrue(sem.acquire(False)) + self.assertTrue(sem.acquire(False)) + self.assertFalse(sem.acquire(False)) + sem.release() + self.assertTrue(sem.acquire(False)) + + def test_try_acquire_contended(self): + sem = self.semtype(4) + sem.acquire() + results = [] + def f(): + results.append(sem.acquire(False)) + results.append(sem.acquire(False)) + Bunch(f, 5).wait_for_finished() + # There can be a thread switch between acquiring the semaphore and + # appending the result, therefore results will not necessarily be + # ordered. + self.assertEqual(sorted(results), [False] * 7 + [True] * 3 ) + + def test_acquire_timeout(self): + sem = self.semtype(2) + self.assertRaises(ValueError, sem.acquire, False, timeout=1.0) + self.assertTrue(sem.acquire(timeout=0.005)) + self.assertTrue(sem.acquire(timeout=0.005)) + self.assertFalse(sem.acquire(timeout=0.005)) + sem.release() + self.assertTrue(sem.acquire(timeout=0.005)) + t = time.monotonic() + self.assertFalse(sem.acquire(timeout=0.5)) + dt = time.monotonic() - t + self.assertTimeout(dt, 0.5) + + def test_default_value(self): + # The default initial value is 1. + sem = self.semtype() + sem.acquire() + def f(): + sem.acquire() + sem.release() + b = Bunch(f, 1) + b.wait_for_started() + _wait() + self.assertFalse(b.finished) + sem.release() + b.wait_for_finished() + + def test_with(self): + sem = self.semtype(2) + def _with(err=None): + with sem: + self.assertTrue(sem.acquire(False)) + sem.release() + with sem: + self.assertFalse(sem.acquire(False)) + if err: + raise err + _with() + self.assertTrue(sem.acquire(False)) + sem.release() + self.assertRaises(TypeError, _with, TypeError) + self.assertTrue(sem.acquire(False)) + sem.release() + +class SemaphoreTests(BaseSemaphoreTests): + """ + Tests for unbounded semaphores. + """ + + def test_release_unacquired(self): + # Unbounded releases are allowed and increment the semaphore's value + sem = self.semtype(1) + sem.release() + sem.acquire() + sem.acquire() + sem.release() + + +class BoundedSemaphoreTests(BaseSemaphoreTests): + """ + Tests for bounded semaphores. + """ + + def test_release_unacquired(self): + # Cannot go past the initial value + sem = self.semtype() + self.assertRaises(ValueError, sem.release) + sem.acquire() + sem.release() + self.assertRaises(ValueError, sem.release) + + +class BarrierTests(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 2.0 + + def setUp(self): + self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout) + def tearDown(self): + self.barrier.abort() + + def run_threads(self, f): + b = Bunch(f, self.N-1) + f() + b.wait_for_finished() + + def multipass(self, results, n): + m = self.barrier.parties + self.assertEqual(m, self.N) + for i in range(n): + results[0].append(True) + self.assertEqual(len(results[1]), i * m) + self.barrier.wait() + results[1].append(True) + self.assertEqual(len(results[0]), (i + 1) * m) + self.barrier.wait() + self.assertEqual(self.barrier.n_waiting, 0) + self.assertFalse(self.barrier.broken) + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [[],[]] + def f(): + self.multipass(results, passes) + self.run_threads(f) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + results = [] + def f(): + r = self.barrier.wait() + results.append(r) + + self.run_threads(f) + self.assertEqual(sum(results), sum(range(self.N))) + + def test_action(self): + """ + Test the 'action' callback + """ + results = [] + def action(): + results.append(True) + barrier = self.barriertype(self.N, action) + def f(): + barrier.wait() + self.assertEqual(len(results), 1) + + self.run_threads(f) + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = [] + results2 = [] + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = [] + results2 = [] + results3 = [] + def f(): + i = self.barrier.wait() + if i == self.N//2: + # Wait until the other threads are all in the barrier. + while self.barrier.n_waiting < self.N-1: + time.sleep(0.001) + self.barrier.reset() + else: + try: + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = [] + results2 = [] + results3 = [] + barrier2 = self.barriertype(self.N) + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == self.N//2: + self.barrier.reset() + barrier2.wait() + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + def test_timeout(self): + """ + Test wait(timeout) + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is late! + time.sleep(1.0) + # Default timeout is 2.0, so this is shorter. + self.assertRaises(threading.BrokenBarrierError, + self.barrier.wait, 0.5) + self.run_threads(f) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + # create a barrier with a low default timeout + barrier = self.barriertype(self.N, timeout=0.3) + def f(): + i = barrier.wait() + if i == self.N // 2: + # One thread is later than the default timeout of 0.3s. + time.sleep(1.0) + self.assertRaises(threading.BrokenBarrierError, barrier.wait) + self.run_threads(f) + + def test_single_thread(self): + b = self.barriertype(1) + b.wait() + b.wait() diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 2c6e0f84fa..2b0d8a9d79 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -681,6 +681,45 @@ def test_replace_overflow(self): self.checkraises(OverflowError, A2_16, "replace", "A", A2_16) self.checkraises(OverflowError, A2_16, "replace", "AA", A2_16+A2_16) + + # Python 3.9 + def test_removeprefix(self): + self.checkequal('am', 'spam', 'removeprefix', 'sp') + self.checkequal('spamspam', 'spamspamspam', 'removeprefix', 'spam') + self.checkequal('spam', 'spam', 'removeprefix', 'python') + self.checkequal('spam', 'spam', 'removeprefix', 'spider') + self.checkequal('spam', 'spam', 'removeprefix', 'spam and eggs') + + self.checkequal('', '', 'removeprefix', '') + self.checkequal('', '', 'removeprefix', 'abcde') + self.checkequal('abcde', 'abcde', 'removeprefix', '') + self.checkequal('', 'abcde', 'removeprefix', 'abcde') + + self.checkraises(TypeError, 'hello', 'removeprefix') + self.checkraises(TypeError, 'hello', 'removeprefix', 42) + self.checkraises(TypeError, 'hello', 'removeprefix', 42, 'h') + self.checkraises(TypeError, 'hello', 'removeprefix', 'h', 42) + self.checkraises(TypeError, 'hello', 'removeprefix', ("he", "l")) + + # Python 3.9 + def test_removesuffix(self): + self.checkequal('sp', 'spam', 'removesuffix', 'am') + self.checkequal('spamspam', 'spamspamspam', 'removesuffix', 'spam') + self.checkequal('spam', 'spam', 'removesuffix', 'python') + self.checkequal('spam', 'spam', 'removesuffix', 'blam') + self.checkequal('spam', 'spam', 'removesuffix', 'eggs and spam') + + self.checkequal('', '', 'removesuffix', '') + self.checkequal('', '', 'removesuffix', 'abcde') + self.checkequal('abcde', 'abcde', 'removesuffix', '') + self.checkequal('', 'abcde', 'removesuffix', 'abcde') + + self.checkraises(TypeError, 'hello', 'removesuffix') + self.checkraises(TypeError, 'hello', 'removesuffix', 42) + self.checkraises(TypeError, 'hello', 'removesuffix', 42, 'h') + self.checkraises(TypeError, 'hello', 'removesuffix', 'h', 42) + self.checkraises(TypeError, 'hello', 'removesuffix', ("lo", "l")) + def test_capitalize(self): self.checkequal(' hello ', ' hello ', 'capitalize') self.checkequal('Hello ', 'Hello ','capitalize') diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 4ea5c2be91..ee644b1269 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -3,11 +3,11 @@ if __name__ != 'test.support': raise ImportError('support must be imported from the test package') -# import asyncio.events +import asyncio.events import collections.abc import contextlib import errno -# import faulthandler +import faulthandler import fnmatch import functools # import gc @@ -16,7 +16,7 @@ import importlib import importlib.util import locale -# import logging.handlers +import logging.handlers # import nntplib import os import platform @@ -28,13 +28,13 @@ import subprocess import sys import sysconfig -# import tempfile +import tempfile import _thread -# import threading +import threading import time import types import unittest -# import urllib.error +import urllib.error import warnings from .testresult import get_test_runner diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py new file mode 100644 index 0000000000..e849c7ba49 --- /dev/null +++ b/Lib/test/test_argparse.py @@ -0,0 +1,5161 @@ +# Author: Steven J. Bethard . + +import codecs +import inspect +import os +import shutil +import stat +import sys +import textwrap +import tempfile +import unittest +import argparse + +from io import StringIO + +from test import support +from unittest import mock +class StdIOBuffer(StringIO): + pass + +class TestCase(unittest.TestCase): + + def setUp(self): + # The tests assume that line wrapping occurs at 80 columns, but this + # behaviour can be overridden by setting the COLUMNS environment + # variable. To ensure that this width is used, set COLUMNS to 80. + env = support.EnvironmentVarGuard() + env['COLUMNS'] = '80' + self.addCleanup(env.__exit__) + + +class TempDirMixin(object): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.old_dir = os.getcwd() + os.chdir(self.temp_dir) + + def tearDown(self): + os.chdir(self.old_dir) + for root, dirs, files in os.walk(self.temp_dir, topdown=False): + for name in files: + os.chmod(os.path.join(self.temp_dir, name), stat.S_IWRITE) + shutil.rmtree(self.temp_dir, True) + + def create_readonly_file(self, filename): + file_path = os.path.join(self.temp_dir, filename) + with open(file_path, 'w') as file: + file.write(filename) + os.chmod(file_path, stat.S_IREAD) + +class Sig(object): + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +class NS(object): + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __repr__(self): + sorted_items = sorted(self.__dict__.items()) + kwarg_str = ', '.join(['%s=%r' % tup for tup in sorted_items]) + return '%s(%s)' % (type(self).__name__, kwarg_str) + + def __eq__(self, other): + return vars(self) == vars(other) + + +class ArgumentParserError(Exception): + + def __init__(self, message, stdout=None, stderr=None, error_code=None): + Exception.__init__(self, message, stdout, stderr) + self.message = message + self.stdout = stdout + self.stderr = stderr + self.error_code = error_code + + +def stderr_to_parser_error(parse_args, *args, **kwargs): + # if this is being called recursively and stderr or stdout is already being + # redirected, simply call the function and let the enclosing function + # catch the exception + if isinstance(sys.stderr, StdIOBuffer) or isinstance(sys.stdout, StdIOBuffer): + return parse_args(*args, **kwargs) + + # if this is not being called recursively, redirect stderr and + # use it as the ArgumentParserError message + old_stdout = sys.stdout + old_stderr = sys.stderr + sys.stdout = StdIOBuffer() + sys.stderr = StdIOBuffer() + try: + try: + result = parse_args(*args, **kwargs) + for key in list(vars(result)): + if getattr(result, key) is sys.stdout: + setattr(result, key, old_stdout) + if getattr(result, key) is sys.stderr: + setattr(result, key, old_stderr) + return result + except SystemExit: + code = sys.exc_info()[1].code + stdout = sys.stdout.getvalue() + stderr = sys.stderr.getvalue() + raise ArgumentParserError("SystemExit", stdout, stderr, code) + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + +class ErrorRaisingArgumentParser(argparse.ArgumentParser): + + def parse_args(self, *args, **kwargs): + parse_args = super(ErrorRaisingArgumentParser, self).parse_args + return stderr_to_parser_error(parse_args, *args, **kwargs) + + def exit(self, *args, **kwargs): + exit = super(ErrorRaisingArgumentParser, self).exit + return stderr_to_parser_error(exit, *args, **kwargs) + + def error(self, *args, **kwargs): + error = super(ErrorRaisingArgumentParser, self).error + return stderr_to_parser_error(error, *args, **kwargs) + + +class ParserTesterMetaclass(type): + """Adds parser tests using the class attributes. + + Classes of this type should specify the following attributes: + + argument_signatures -- a list of Sig objects which specify + the signatures of Argument objects to be created + failures -- a list of args lists that should cause the parser + to fail + successes -- a list of (initial_args, options, remaining_args) tuples + where initial_args specifies the string args to be parsed, + options is a dict that should match the vars() of the options + parsed out of initial_args, and remaining_args should be any + remaining unparsed arguments + """ + + def __init__(cls, name, bases, bodydict): + if name == 'ParserTestCase': + return + + # default parser signature is empty + if not hasattr(cls, 'parser_signature'): + cls.parser_signature = Sig() + if not hasattr(cls, 'parser_class'): + cls.parser_class = ErrorRaisingArgumentParser + + # --------------------------------------- + # functions for adding optional arguments + # --------------------------------------- + def no_groups(parser, argument_signatures): + """Add all arguments directly to the parser""" + for sig in argument_signatures: + parser.add_argument(*sig.args, **sig.kwargs) + + def one_group(parser, argument_signatures): + """Add all arguments under a single group in the parser""" + group = parser.add_argument_group('foo') + for sig in argument_signatures: + group.add_argument(*sig.args, **sig.kwargs) + + def many_groups(parser, argument_signatures): + """Add each argument in its own group to the parser""" + for i, sig in enumerate(argument_signatures): + group = parser.add_argument_group('foo:%i' % i) + group.add_argument(*sig.args, **sig.kwargs) + + # -------------------------- + # functions for parsing args + # -------------------------- + def listargs(parser, args): + """Parse the args by passing in a list""" + return parser.parse_args(args) + + def sysargs(parser, args): + """Parse the args by defaulting to sys.argv""" + old_sys_argv = sys.argv + sys.argv = [old_sys_argv[0]] + args + try: + return parser.parse_args() + finally: + sys.argv = old_sys_argv + + # class that holds the combination of one optional argument + # addition method and one arg parsing method + class AddTests(object): + + def __init__(self, tester_cls, add_arguments, parse_args): + self._add_arguments = add_arguments + self._parse_args = parse_args + + add_arguments_name = self._add_arguments.__name__ + parse_args_name = self._parse_args.__name__ + for test_func in [self.test_failures, self.test_successes]: + func_name = test_func.__name__ + names = func_name, add_arguments_name, parse_args_name + test_name = '_'.join(names) + + def wrapper(self, test_func=test_func): + test_func(self) + try: + wrapper.__name__ = test_name + except TypeError: + pass + setattr(tester_cls, test_name, wrapper) + + def _get_parser(self, tester): + args = tester.parser_signature.args + kwargs = tester.parser_signature.kwargs + parser = tester.parser_class(*args, **kwargs) + self._add_arguments(parser, tester.argument_signatures) + return parser + + def test_failures(self, tester): + parser = self._get_parser(tester) + for args_str in tester.failures: + args = args_str.split() + with tester.assertRaises(ArgumentParserError, msg=args): + parser.parse_args(args) + + def test_successes(self, tester): + parser = self._get_parser(tester) + for args, expected_ns in tester.successes: + if isinstance(args, str): + args = args.split() + result_ns = self._parse_args(parser, args) + tester.assertEqual(expected_ns, result_ns) + + # add tests for each combination of an optionals adding method + # and an arg parsing method + for add_arguments in [no_groups, one_group, many_groups]: + for parse_args in [listargs, sysargs]: + AddTests(cls, add_arguments, parse_args) + +bases = TestCase, +ParserTestCase = ParserTesterMetaclass('ParserTestCase', bases, {}) + +# =============== +# Optionals tests +# =============== + +class TestOptionalsSingleDash(ParserTestCase): + """Test an Optional with a single-dash option string""" + + argument_signatures = [Sig('-x')] + failures = ['-x', 'a', '--foo', '-x --foo', '-x -y'] + successes = [ + ('', NS(x=None)), + ('-x a', NS(x='a')), + ('-xa', NS(x='a')), + ('-x -1', NS(x='-1')), + ('-x-1', NS(x='-1')), + ] + + +class TestOptionalsSingleDashCombined(ParserTestCase): + """Test an Optional with a single-dash option string""" + + argument_signatures = [ + Sig('-x', action='store_true'), + Sig('-yyy', action='store_const', const=42), + Sig('-z'), + ] + failures = ['a', '--foo', '-xa', '-x --foo', '-x -z', '-z -x', + '-yx', '-yz a', '-yyyx', '-yyyza', '-xyza'] + successes = [ + ('', NS(x=False, yyy=None, z=None)), + ('-x', NS(x=True, yyy=None, z=None)), + ('-za', NS(x=False, yyy=None, z='a')), + ('-z a', NS(x=False, yyy=None, z='a')), + ('-xza', NS(x=True, yyy=None, z='a')), + ('-xz a', NS(x=True, yyy=None, z='a')), + ('-x -za', NS(x=True, yyy=None, z='a')), + ('-x -z a', NS(x=True, yyy=None, z='a')), + ('-y', NS(x=False, yyy=42, z=None)), + ('-yyy', NS(x=False, yyy=42, z=None)), + ('-x -yyy -za', NS(x=True, yyy=42, z='a')), + ('-x -yyy -z a', NS(x=True, yyy=42, z='a')), + ] + + +class TestOptionalsSingleDashLong(ParserTestCase): + """Test an Optional with a multi-character single-dash option string""" + + argument_signatures = [Sig('-foo')] + failures = ['-foo', 'a', '--foo', '-foo --foo', '-foo -y', '-fooa'] + successes = [ + ('', NS(foo=None)), + ('-foo a', NS(foo='a')), + ('-foo -1', NS(foo='-1')), + ('-fo a', NS(foo='a')), + ('-f a', NS(foo='a')), + ] + + +class TestOptionalsSingleDashSubsetAmbiguous(ParserTestCase): + """Test Optionals where option strings are subsets of each other""" + + argument_signatures = [Sig('-f'), Sig('-foobar'), Sig('-foorab')] + failures = ['-f', '-foo', '-fo', '-foo b', '-foob', '-fooba', '-foora'] + successes = [ + ('', NS(f=None, foobar=None, foorab=None)), + ('-f a', NS(f='a', foobar=None, foorab=None)), + ('-fa', NS(f='a', foobar=None, foorab=None)), + ('-foa', NS(f='oa', foobar=None, foorab=None)), + ('-fooa', NS(f='ooa', foobar=None, foorab=None)), + ('-foobar a', NS(f=None, foobar='a', foorab=None)), + ('-foorab a', NS(f=None, foobar=None, foorab='a')), + ] + + +class TestOptionalsSingleDashAmbiguous(ParserTestCase): + """Test Optionals that partially match but are not subsets""" + + argument_signatures = [Sig('-foobar'), Sig('-foorab')] + failures = ['-f', '-f a', '-fa', '-foa', '-foo', '-fo', '-foo b'] + successes = [ + ('', NS(foobar=None, foorab=None)), + ('-foob a', NS(foobar='a', foorab=None)), + ('-foor a', NS(foobar=None, foorab='a')), + ('-fooba a', NS(foobar='a', foorab=None)), + ('-foora a', NS(foobar=None, foorab='a')), + ('-foobar a', NS(foobar='a', foorab=None)), + ('-foorab a', NS(foobar=None, foorab='a')), + ] + + +class TestOptionalsNumeric(ParserTestCase): + """Test an Optional with a short opt string""" + + argument_signatures = [Sig('-1', dest='one')] + failures = ['-1', 'a', '-1 --foo', '-1 -y', '-1 -1', '-1 -2'] + successes = [ + ('', NS(one=None)), + ('-1 a', NS(one='a')), + ('-1a', NS(one='a')), + ('-1-2', NS(one='-2')), + ] + + +class TestOptionalsDoubleDash(ParserTestCase): + """Test an Optional with a double-dash option string""" + + argument_signatures = [Sig('--foo')] + failures = ['--foo', '-f', '-f a', 'a', '--foo -x', '--foo --bar'] + successes = [ + ('', NS(foo=None)), + ('--foo a', NS(foo='a')), + ('--foo=a', NS(foo='a')), + ('--foo -2.5', NS(foo='-2.5')), + ('--foo=-2.5', NS(foo='-2.5')), + ] + + +class TestOptionalsDoubleDashPartialMatch(ParserTestCase): + """Tests partial matching with a double-dash option string""" + + argument_signatures = [ + Sig('--badger', action='store_true'), + Sig('--bat'), + ] + failures = ['--bar', '--b', '--ba', '--b=2', '--ba=4', '--badge 5'] + successes = [ + ('', NS(badger=False, bat=None)), + ('--bat X', NS(badger=False, bat='X')), + ('--bad', NS(badger=True, bat=None)), + ('--badg', NS(badger=True, bat=None)), + ('--badge', NS(badger=True, bat=None)), + ('--badger', NS(badger=True, bat=None)), + ] + + +class TestOptionalsDoubleDashPrefixMatch(ParserTestCase): + """Tests when one double-dash option string is a prefix of another""" + + argument_signatures = [ + Sig('--badger', action='store_true'), + Sig('--ba'), + ] + failures = ['--bar', '--b', '--ba', '--b=2', '--badge 5'] + successes = [ + ('', NS(badger=False, ba=None)), + ('--ba X', NS(badger=False, ba='X')), + ('--ba=X', NS(badger=False, ba='X')), + ('--bad', NS(badger=True, ba=None)), + ('--badg', NS(badger=True, ba=None)), + ('--badge', NS(badger=True, ba=None)), + ('--badger', NS(badger=True, ba=None)), + ] + + +class TestOptionalsSingleDoubleDash(ParserTestCase): + """Test an Optional with single- and double-dash option strings""" + + argument_signatures = [ + Sig('-f', action='store_true'), + Sig('--bar'), + Sig('-baz', action='store_const', const=42), + ] + failures = ['--bar', '-fbar', '-fbaz', '-bazf', '-b B', 'B'] + successes = [ + ('', NS(f=False, bar=None, baz=None)), + ('-f', NS(f=True, bar=None, baz=None)), + ('--ba B', NS(f=False, bar='B', baz=None)), + ('-f --bar B', NS(f=True, bar='B', baz=None)), + ('-f -b', NS(f=True, bar=None, baz=42)), + ('-ba -f', NS(f=True, bar=None, baz=42)), + ] + + +class TestOptionalsAlternatePrefixChars(ParserTestCase): + """Test an Optional with option strings with custom prefixes""" + + parser_signature = Sig(prefix_chars='+:/', add_help=False) + argument_signatures = [ + Sig('+f', action='store_true'), + Sig('::bar'), + Sig('/baz', action='store_const', const=42), + ] + failures = ['--bar', '-fbar', '-b B', 'B', '-f', '--bar B', '-baz', '-h', '--help', '+h', '::help', '/help'] + successes = [ + ('', NS(f=False, bar=None, baz=None)), + ('+f', NS(f=True, bar=None, baz=None)), + ('::ba B', NS(f=False, bar='B', baz=None)), + ('+f ::bar B', NS(f=True, bar='B', baz=None)), + ('+f /b', NS(f=True, bar=None, baz=42)), + ('/ba +f', NS(f=True, bar=None, baz=42)), + ] + + +class TestOptionalsAlternatePrefixCharsAddedHelp(ParserTestCase): + """When ``-`` not in prefix_chars, default operators created for help + should use the prefix_chars in use rather than - or -- + http://bugs.python.org/issue9444""" + + parser_signature = Sig(prefix_chars='+:/', add_help=True) + argument_signatures = [ + Sig('+f', action='store_true'), + Sig('::bar'), + Sig('/baz', action='store_const', const=42), + ] + failures = ['--bar', '-fbar', '-b B', 'B', '-f', '--bar B', '-baz'] + successes = [ + ('', NS(f=False, bar=None, baz=None)), + ('+f', NS(f=True, bar=None, baz=None)), + ('::ba B', NS(f=False, bar='B', baz=None)), + ('+f ::bar B', NS(f=True, bar='B', baz=None)), + ('+f /b', NS(f=True, bar=None, baz=42)), + ('/ba +f', NS(f=True, bar=None, baz=42)) + ] + + +class TestOptionalsAlternatePrefixCharsMultipleShortArgs(ParserTestCase): + """Verify that Optionals must be called with their defined prefixes""" + + parser_signature = Sig(prefix_chars='+-', add_help=False) + argument_signatures = [ + Sig('-x', action='store_true'), + Sig('+y', action='store_true'), + Sig('+z', action='store_true'), + ] + failures = ['-w', + '-xyz', + '+x', + '-y', + '+xyz', + ] + successes = [ + ('', NS(x=False, y=False, z=False)), + ('-x', NS(x=True, y=False, z=False)), + ('+y -x', NS(x=True, y=True, z=False)), + ('+yz -x', NS(x=True, y=True, z=True)), + ] + + +class TestOptionalsShortLong(ParserTestCase): + """Test a combination of single- and double-dash option strings""" + + argument_signatures = [ + Sig('-v', '--verbose', '-n', '--noisy', action='store_true'), + ] + failures = ['--x --verbose', '-N', 'a', '-v x'] + successes = [ + ('', NS(verbose=False)), + ('-v', NS(verbose=True)), + ('--verbose', NS(verbose=True)), + ('-n', NS(verbose=True)), + ('--noisy', NS(verbose=True)), + ] + + +class TestOptionalsDest(ParserTestCase): + """Tests various means of setting destination""" + + argument_signatures = [Sig('--foo-bar'), Sig('--baz', dest='zabbaz')] + failures = ['a'] + successes = [ + ('--foo-bar f', NS(foo_bar='f', zabbaz=None)), + ('--baz g', NS(foo_bar=None, zabbaz='g')), + ('--foo-bar h --baz i', NS(foo_bar='h', zabbaz='i')), + ('--baz j --foo-bar k', NS(foo_bar='k', zabbaz='j')), + ] + + +class TestOptionalsDefault(ParserTestCase): + """Tests specifying a default for an Optional""" + + argument_signatures = [Sig('-x'), Sig('-y', default=42)] + failures = ['a'] + successes = [ + ('', NS(x=None, y=42)), + ('-xx', NS(x='x', y=42)), + ('-yy', NS(x=None, y='y')), + ] + + +class TestOptionalsNargsDefault(ParserTestCase): + """Tests not specifying the number of args for an Optional""" + + argument_signatures = [Sig('-x')] + failures = ['a', '-x'] + successes = [ + ('', NS(x=None)), + ('-x a', NS(x='a')), + ] + + +class TestOptionalsNargs1(ParserTestCase): + """Tests specifying 1 arg for an Optional""" + + argument_signatures = [Sig('-x', nargs=1)] + failures = ['a', '-x'] + successes = [ + ('', NS(x=None)), + ('-x a', NS(x=['a'])), + ] + + +class TestOptionalsNargs3(ParserTestCase): + """Tests specifying 3 args for an Optional""" + + argument_signatures = [Sig('-x', nargs=3)] + failures = ['a', '-x', '-x a', '-x a b', 'a -x', 'a -x b'] + successes = [ + ('', NS(x=None)), + ('-x a b c', NS(x=['a', 'b', 'c'])), + ] + + +class TestOptionalsNargsOptional(ParserTestCase): + """Tests specifying an Optional arg for an Optional""" + + argument_signatures = [ + Sig('-w', nargs='?'), + Sig('-x', nargs='?', const=42), + Sig('-y', nargs='?', default='spam'), + Sig('-z', nargs='?', type=int, const='42', default='84'), + ] + failures = ['2'] + successes = [ + ('', NS(w=None, x=None, y='spam', z=84)), + ('-w', NS(w=None, x=None, y='spam', z=84)), + ('-w 2', NS(w='2', x=None, y='spam', z=84)), + ('-x', NS(w=None, x=42, y='spam', z=84)), + ('-x 2', NS(w=None, x='2', y='spam', z=84)), + ('-y', NS(w=None, x=None, y=None, z=84)), + ('-y 2', NS(w=None, x=None, y='2', z=84)), + ('-z', NS(w=None, x=None, y='spam', z=42)), + ('-z 2', NS(w=None, x=None, y='spam', z=2)), + ] + + +class TestOptionalsNargsZeroOrMore(ParserTestCase): + """Tests specifying args for an Optional that accepts zero or more""" + + argument_signatures = [ + Sig('-x', nargs='*'), + Sig('-y', nargs='*', default='spam'), + ] + failures = ['a'] + successes = [ + ('', NS(x=None, y='spam')), + ('-x', NS(x=[], y='spam')), + ('-x a', NS(x=['a'], y='spam')), + ('-x a b', NS(x=['a', 'b'], y='spam')), + ('-y', NS(x=None, y=[])), + ('-y a', NS(x=None, y=['a'])), + ('-y a b', NS(x=None, y=['a', 'b'])), + ] + + +class TestOptionalsNargsOneOrMore(ParserTestCase): + """Tests specifying args for an Optional that accepts one or more""" + + argument_signatures = [ + Sig('-x', nargs='+'), + Sig('-y', nargs='+', default='spam'), + ] + failures = ['a', '-x', '-y', 'a -x', 'a -y b'] + successes = [ + ('', NS(x=None, y='spam')), + ('-x a', NS(x=['a'], y='spam')), + ('-x a b', NS(x=['a', 'b'], y='spam')), + ('-y a', NS(x=None, y=['a'])), + ('-y a b', NS(x=None, y=['a', 'b'])), + ] + + +class TestOptionalsChoices(ParserTestCase): + """Tests specifying the choices for an Optional""" + + argument_signatures = [ + Sig('-f', choices='abc'), + Sig('-g', type=int, choices=range(5))] + failures = ['a', '-f d', '-fad', '-ga', '-g 6'] + successes = [ + ('', NS(f=None, g=None)), + ('-f a', NS(f='a', g=None)), + ('-f c', NS(f='c', g=None)), + ('-g 0', NS(f=None, g=0)), + ('-g 03', NS(f=None, g=3)), + ('-fb -g4', NS(f='b', g=4)), + ] + + +class TestOptionalsRequired(ParserTestCase): + """Tests an optional action that is required""" + + argument_signatures = [ + Sig('-x', type=int, required=True), + ] + failures = ['a', ''] + successes = [ + ('-x 1', NS(x=1)), + ('-x42', NS(x=42)), + ] + + +class TestOptionalsActionStore(ParserTestCase): + """Tests the store action for an Optional""" + + argument_signatures = [Sig('-x', action='store')] + failures = ['a', 'a -x'] + successes = [ + ('', NS(x=None)), + ('-xfoo', NS(x='foo')), + ] + + +class TestOptionalsActionStoreConst(ParserTestCase): + """Tests the store_const action for an Optional""" + + argument_signatures = [Sig('-y', action='store_const', const=object)] + failures = ['a'] + successes = [ + ('', NS(y=None)), + ('-y', NS(y=object)), + ] + + +class TestOptionalsActionStoreFalse(ParserTestCase): + """Tests the store_false action for an Optional""" + + argument_signatures = [Sig('-z', action='store_false')] + failures = ['a', '-za', '-z a'] + successes = [ + ('', NS(z=True)), + ('-z', NS(z=False)), + ] + + +class TestOptionalsActionStoreTrue(ParserTestCase): + """Tests the store_true action for an Optional""" + + argument_signatures = [Sig('--apple', action='store_true')] + failures = ['a', '--apple=b', '--apple b'] + successes = [ + ('', NS(apple=False)), + ('--apple', NS(apple=True)), + ] + + +class TestOptionalsActionAppend(ParserTestCase): + """Tests the append action for an Optional""" + + argument_signatures = [Sig('--baz', action='append')] + failures = ['a', '--baz', 'a --baz', '--baz a b'] + successes = [ + ('', NS(baz=None)), + ('--baz a', NS(baz=['a'])), + ('--baz a --baz b', NS(baz=['a', 'b'])), + ] + + +class TestOptionalsActionAppendWithDefault(ParserTestCase): + """Tests the append action for an Optional""" + + argument_signatures = [Sig('--baz', action='append', default=['X'])] + failures = ['a', '--baz', 'a --baz', '--baz a b'] + successes = [ + ('', NS(baz=['X'])), + ('--baz a', NS(baz=['X', 'a'])), + ('--baz a --baz b', NS(baz=['X', 'a', 'b'])), + ] + + +class TestOptionalsActionAppendConst(ParserTestCase): + """Tests the append_const action for an Optional""" + + argument_signatures = [ + Sig('-b', action='append_const', const=Exception), + Sig('-c', action='append', dest='b'), + ] + failures = ['a', '-c', 'a -c', '-bx', '-b x'] + successes = [ + ('', NS(b=None)), + ('-b', NS(b=[Exception])), + ('-b -cx -b -cyz', NS(b=[Exception, 'x', Exception, 'yz'])), + ] + + +class TestOptionalsActionAppendConstWithDefault(ParserTestCase): + """Tests the append_const action for an Optional""" + + argument_signatures = [ + Sig('-b', action='append_const', const=Exception, default=['X']), + Sig('-c', action='append', dest='b'), + ] + failures = ['a', '-c', 'a -c', '-bx', '-b x'] + successes = [ + ('', NS(b=['X'])), + ('-b', NS(b=['X', Exception])), + ('-b -cx -b -cyz', NS(b=['X', Exception, 'x', Exception, 'yz'])), + ] + + +class TestOptionalsActionCount(ParserTestCase): + """Tests the count action for an Optional""" + + argument_signatures = [Sig('-x', action='count')] + failures = ['a', '-x a', '-x b', '-x a -x b'] + successes = [ + ('', NS(x=None)), + ('-x', NS(x=1)), + ] + + +class TestOptionalsAllowLongAbbreviation(ParserTestCase): + """Allow long options to be abbreviated unambiguously""" + + argument_signatures = [ + Sig('--foo'), + Sig('--foobaz'), + Sig('--fooble', action='store_true'), + ] + failures = ['--foob 5', '--foob'] + successes = [ + ('', NS(foo=None, foobaz=None, fooble=False)), + ('--foo 7', NS(foo='7', foobaz=None, fooble=False)), + ('--fooba a', NS(foo=None, foobaz='a', fooble=False)), + ('--foobl --foo g', NS(foo='g', foobaz=None, fooble=True)), + ] + + +class TestOptionalsDisallowLongAbbreviation(ParserTestCase): + """Do not allow abbreviations of long options at all""" + + parser_signature = Sig(allow_abbrev=False) + argument_signatures = [ + Sig('--foo'), + Sig('--foodle', action='store_true'), + Sig('--foonly'), + ] + failures = ['-foon 3', '--foon 3', '--food', '--food --foo 2'] + successes = [ + ('', NS(foo=None, foodle=False, foonly=None)), + ('--foo 3', NS(foo='3', foodle=False, foonly=None)), + ('--foonly 7 --foodle --foo 2', NS(foo='2', foodle=True, foonly='7')), + ] + +# ================ +# Positional tests +# ================ + +class TestPositionalsNargsNone(ParserTestCase): + """Test a Positional that doesn't specify nargs""" + + argument_signatures = [Sig('foo')] + failures = ['', '-x', 'a b'] + successes = [ + ('a', NS(foo='a')), + ] + + +class TestPositionalsNargs1(ParserTestCase): + """Test a Positional that specifies an nargs of 1""" + + argument_signatures = [Sig('foo', nargs=1)] + failures = ['', '-x', 'a b'] + successes = [ + ('a', NS(foo=['a'])), + ] + + +class TestPositionalsNargs2(ParserTestCase): + """Test a Positional that specifies an nargs of 2""" + + argument_signatures = [Sig('foo', nargs=2)] + failures = ['', 'a', '-x', 'a b c'] + successes = [ + ('a b', NS(foo=['a', 'b'])), + ] + + +class TestPositionalsNargsZeroOrMore(ParserTestCase): + """Test a Positional that specifies unlimited nargs""" + + argument_signatures = [Sig('foo', nargs='*')] + failures = ['-x'] + successes = [ + ('', NS(foo=[])), + ('a', NS(foo=['a'])), + ('a b', NS(foo=['a', 'b'])), + ] + + +class TestPositionalsNargsZeroOrMoreDefault(ParserTestCase): + """Test a Positional that specifies unlimited nargs and a default""" + + argument_signatures = [Sig('foo', nargs='*', default='bar')] + failures = ['-x'] + successes = [ + ('', NS(foo='bar')), + ('a', NS(foo=['a'])), + ('a b', NS(foo=['a', 'b'])), + ] + + +class TestPositionalsNargsOneOrMore(ParserTestCase): + """Test a Positional that specifies one or more nargs""" + + argument_signatures = [Sig('foo', nargs='+')] + failures = ['', '-x'] + successes = [ + ('a', NS(foo=['a'])), + ('a b', NS(foo=['a', 'b'])), + ] + + +class TestPositionalsNargsOptional(ParserTestCase): + """Tests an Optional Positional""" + + argument_signatures = [Sig('foo', nargs='?')] + failures = ['-x', 'a b'] + successes = [ + ('', NS(foo=None)), + ('a', NS(foo='a')), + ] + + +class TestPositionalsNargsOptionalDefault(ParserTestCase): + """Tests an Optional Positional with a default value""" + + argument_signatures = [Sig('foo', nargs='?', default=42)] + failures = ['-x', 'a b'] + successes = [ + ('', NS(foo=42)), + ('a', NS(foo='a')), + ] + + +class TestPositionalsNargsOptionalConvertedDefault(ParserTestCase): + """Tests an Optional Positional with a default value + that needs to be converted to the appropriate type. + """ + + argument_signatures = [ + Sig('foo', nargs='?', type=int, default='42'), + ] + failures = ['-x', 'a b', '1 2'] + successes = [ + ('', NS(foo=42)), + ('1', NS(foo=1)), + ] + + +class TestPositionalsNargsNoneNone(ParserTestCase): + """Test two Positionals that don't specify nargs""" + + argument_signatures = [Sig('foo'), Sig('bar')] + failures = ['', '-x', 'a', 'a b c'] + successes = [ + ('a b', NS(foo='a', bar='b')), + ] + + +class TestPositionalsNargsNone1(ParserTestCase): + """Test a Positional with no nargs followed by one with 1""" + + argument_signatures = [Sig('foo'), Sig('bar', nargs=1)] + failures = ['', '--foo', 'a', 'a b c'] + successes = [ + ('a b', NS(foo='a', bar=['b'])), + ] + + +class TestPositionalsNargs2None(ParserTestCase): + """Test a Positional with 2 nargs followed by one with none""" + + argument_signatures = [Sig('foo', nargs=2), Sig('bar')] + failures = ['', '--foo', 'a', 'a b', 'a b c d'] + successes = [ + ('a b c', NS(foo=['a', 'b'], bar='c')), + ] + + +class TestPositionalsNargsNoneZeroOrMore(ParserTestCase): + """Test a Positional with no nargs followed by one with unlimited""" + + argument_signatures = [Sig('foo'), Sig('bar', nargs='*')] + failures = ['', '--foo'] + successes = [ + ('a', NS(foo='a', bar=[])), + ('a b', NS(foo='a', bar=['b'])), + ('a b c', NS(foo='a', bar=['b', 'c'])), + ] + + +class TestPositionalsNargsNoneOneOrMore(ParserTestCase): + """Test a Positional with no nargs followed by one with one or more""" + + argument_signatures = [Sig('foo'), Sig('bar', nargs='+')] + failures = ['', '--foo', 'a'] + successes = [ + ('a b', NS(foo='a', bar=['b'])), + ('a b c', NS(foo='a', bar=['b', 'c'])), + ] + + +class TestPositionalsNargsNoneOptional(ParserTestCase): + """Test a Positional with no nargs followed by one with an Optional""" + + argument_signatures = [Sig('foo'), Sig('bar', nargs='?')] + failures = ['', '--foo', 'a b c'] + successes = [ + ('a', NS(foo='a', bar=None)), + ('a b', NS(foo='a', bar='b')), + ] + + +class TestPositionalsNargsZeroOrMoreNone(ParserTestCase): + """Test a Positional with unlimited nargs followed by one with none""" + + argument_signatures = [Sig('foo', nargs='*'), Sig('bar')] + failures = ['', '--foo'] + successes = [ + ('a', NS(foo=[], bar='a')), + ('a b', NS(foo=['a'], bar='b')), + ('a b c', NS(foo=['a', 'b'], bar='c')), + ] + + +class TestPositionalsNargsOneOrMoreNone(ParserTestCase): + """Test a Positional with one or more nargs followed by one with none""" + + argument_signatures = [Sig('foo', nargs='+'), Sig('bar')] + failures = ['', '--foo', 'a'] + successes = [ + ('a b', NS(foo=['a'], bar='b')), + ('a b c', NS(foo=['a', 'b'], bar='c')), + ] + + +class TestPositionalsNargsOptionalNone(ParserTestCase): + """Test a Positional with an Optional nargs followed by one with none""" + + argument_signatures = [Sig('foo', nargs='?', default=42), Sig('bar')] + failures = ['', '--foo', 'a b c'] + successes = [ + ('a', NS(foo=42, bar='a')), + ('a b', NS(foo='a', bar='b')), + ] + + +class TestPositionalsNargs2ZeroOrMore(ParserTestCase): + """Test a Positional with 2 nargs followed by one with unlimited""" + + argument_signatures = [Sig('foo', nargs=2), Sig('bar', nargs='*')] + failures = ['', '--foo', 'a'] + successes = [ + ('a b', NS(foo=['a', 'b'], bar=[])), + ('a b c', NS(foo=['a', 'b'], bar=['c'])), + ] + + +class TestPositionalsNargs2OneOrMore(ParserTestCase): + """Test a Positional with 2 nargs followed by one with one or more""" + + argument_signatures = [Sig('foo', nargs=2), Sig('bar', nargs='+')] + failures = ['', '--foo', 'a', 'a b'] + successes = [ + ('a b c', NS(foo=['a', 'b'], bar=['c'])), + ] + + +class TestPositionalsNargs2Optional(ParserTestCase): + """Test a Positional with 2 nargs followed by one optional""" + + argument_signatures = [Sig('foo', nargs=2), Sig('bar', nargs='?')] + failures = ['', '--foo', 'a', 'a b c d'] + successes = [ + ('a b', NS(foo=['a', 'b'], bar=None)), + ('a b c', NS(foo=['a', 'b'], bar='c')), + ] + + +class TestPositionalsNargsZeroOrMore1(ParserTestCase): + """Test a Positional with unlimited nargs followed by one with 1""" + + argument_signatures = [Sig('foo', nargs='*'), Sig('bar', nargs=1)] + failures = ['', '--foo', ] + successes = [ + ('a', NS(foo=[], bar=['a'])), + ('a b', NS(foo=['a'], bar=['b'])), + ('a b c', NS(foo=['a', 'b'], bar=['c'])), + ] + + +class TestPositionalsNargsOneOrMore1(ParserTestCase): + """Test a Positional with one or more nargs followed by one with 1""" + + argument_signatures = [Sig('foo', nargs='+'), Sig('bar', nargs=1)] + failures = ['', '--foo', 'a'] + successes = [ + ('a b', NS(foo=['a'], bar=['b'])), + ('a b c', NS(foo=['a', 'b'], bar=['c'])), + ] + + +class TestPositionalsNargsOptional1(ParserTestCase): + """Test a Positional with an Optional nargs followed by one with 1""" + + argument_signatures = [Sig('foo', nargs='?'), Sig('bar', nargs=1)] + failures = ['', '--foo', 'a b c'] + successes = [ + ('a', NS(foo=None, bar=['a'])), + ('a b', NS(foo='a', bar=['b'])), + ] + + +class TestPositionalsNargsNoneZeroOrMore1(ParserTestCase): + """Test three Positionals: no nargs, unlimited nargs and 1 nargs""" + + argument_signatures = [ + Sig('foo'), + Sig('bar', nargs='*'), + Sig('baz', nargs=1), + ] + failures = ['', '--foo', 'a'] + successes = [ + ('a b', NS(foo='a', bar=[], baz=['b'])), + ('a b c', NS(foo='a', bar=['b'], baz=['c'])), + ] + + +class TestPositionalsNargsNoneOneOrMore1(ParserTestCase): + """Test three Positionals: no nargs, one or more nargs and 1 nargs""" + + argument_signatures = [ + Sig('foo'), + Sig('bar', nargs='+'), + Sig('baz', nargs=1), + ] + failures = ['', '--foo', 'a', 'b'] + successes = [ + ('a b c', NS(foo='a', bar=['b'], baz=['c'])), + ('a b c d', NS(foo='a', bar=['b', 'c'], baz=['d'])), + ] + + +class TestPositionalsNargsNoneOptional1(ParserTestCase): + """Test three Positionals: no nargs, optional narg and 1 nargs""" + + argument_signatures = [ + Sig('foo'), + Sig('bar', nargs='?', default=0.625), + Sig('baz', nargs=1), + ] + failures = ['', '--foo', 'a'] + successes = [ + ('a b', NS(foo='a', bar=0.625, baz=['b'])), + ('a b c', NS(foo='a', bar='b', baz=['c'])), + ] + + +class TestPositionalsNargsOptionalOptional(ParserTestCase): + """Test two optional nargs""" + + argument_signatures = [ + Sig('foo', nargs='?'), + Sig('bar', nargs='?', default=42), + ] + failures = ['--foo', 'a b c'] + successes = [ + ('', NS(foo=None, bar=42)), + ('a', NS(foo='a', bar=42)), + ('a b', NS(foo='a', bar='b')), + ] + + +class TestPositionalsNargsOptionalZeroOrMore(ParserTestCase): + """Test an Optional narg followed by unlimited nargs""" + + argument_signatures = [Sig('foo', nargs='?'), Sig('bar', nargs='*')] + failures = ['--foo'] + successes = [ + ('', NS(foo=None, bar=[])), + ('a', NS(foo='a', bar=[])), + ('a b', NS(foo='a', bar=['b'])), + ('a b c', NS(foo='a', bar=['b', 'c'])), + ] + + +class TestPositionalsNargsOptionalOneOrMore(ParserTestCase): + """Test an Optional narg followed by one or more nargs""" + + argument_signatures = [Sig('foo', nargs='?'), Sig('bar', nargs='+')] + failures = ['', '--foo'] + successes = [ + ('a', NS(foo=None, bar=['a'])), + ('a b', NS(foo='a', bar=['b'])), + ('a b c', NS(foo='a', bar=['b', 'c'])), + ] + + +class TestPositionalsChoicesString(ParserTestCase): + """Test a set of single-character choices""" + + argument_signatures = [Sig('spam', choices=set('abcdefg'))] + failures = ['', '--foo', 'h', '42', 'ef'] + successes = [ + ('a', NS(spam='a')), + ('g', NS(spam='g')), + ] + + +class TestPositionalsChoicesInt(ParserTestCase): + """Test a set of integer choices""" + + argument_signatures = [Sig('spam', type=int, choices=range(20))] + failures = ['', '--foo', 'h', '42', 'ef'] + successes = [ + ('4', NS(spam=4)), + ('15', NS(spam=15)), + ] + + +class TestPositionalsActionAppend(ParserTestCase): + """Test the 'append' action""" + + argument_signatures = [ + Sig('spam', action='append'), + Sig('spam', action='append', nargs=2), + ] + failures = ['', '--foo', 'a', 'a b', 'a b c d'] + successes = [ + ('a b c', NS(spam=['a', ['b', 'c']])), + ] + +# ======================================== +# Combined optionals and positionals tests +# ======================================== + +class TestOptionalsNumericAndPositionals(ParserTestCase): + """Tests negative number args when numeric options are present""" + + argument_signatures = [ + Sig('x', nargs='?'), + Sig('-4', dest='y', action='store_true'), + ] + failures = ['-2', '-315'] + successes = [ + ('', NS(x=None, y=False)), + ('a', NS(x='a', y=False)), + ('-4', NS(x=None, y=True)), + ('-4 a', NS(x='a', y=True)), + ] + + +class TestOptionalsAlmostNumericAndPositionals(ParserTestCase): + """Tests negative number args when almost numeric options are present""" + + argument_signatures = [ + Sig('x', nargs='?'), + Sig('-k4', dest='y', action='store_true'), + ] + failures = ['-k3'] + successes = [ + ('', NS(x=None, y=False)), + ('-2', NS(x='-2', y=False)), + ('a', NS(x='a', y=False)), + ('-k4', NS(x=None, y=True)), + ('-k4 a', NS(x='a', y=True)), + ] + + +class TestEmptyAndSpaceContainingArguments(ParserTestCase): + + argument_signatures = [ + Sig('x', nargs='?'), + Sig('-y', '--yyy', dest='y'), + ] + failures = ['-y'] + successes = [ + ([''], NS(x='', y=None)), + (['a badger'], NS(x='a badger', y=None)), + (['-a badger'], NS(x='-a badger', y=None)), + (['-y', ''], NS(x=None, y='')), + (['-y', 'a badger'], NS(x=None, y='a badger')), + (['-y', '-a badger'], NS(x=None, y='-a badger')), + (['--yyy=a badger'], NS(x=None, y='a badger')), + (['--yyy=-a badger'], NS(x=None, y='-a badger')), + ] + + +class TestPrefixCharacterOnlyArguments(ParserTestCase): + + parser_signature = Sig(prefix_chars='-+') + argument_signatures = [ + Sig('-', dest='x', nargs='?', const='badger'), + Sig('+', dest='y', type=int, default=42), + Sig('-+-', dest='z', action='store_true'), + ] + failures = ['-y', '+ -'] + successes = [ + ('', NS(x=None, y=42, z=False)), + ('-', NS(x='badger', y=42, z=False)), + ('- X', NS(x='X', y=42, z=False)), + ('+ -3', NS(x=None, y=-3, z=False)), + ('-+-', NS(x=None, y=42, z=True)), + ('- ===', NS(x='===', y=42, z=False)), + ] + + +class TestNargsZeroOrMore(ParserTestCase): + """Tests specifying args for an Optional that accepts zero or more""" + + argument_signatures = [Sig('-x', nargs='*'), Sig('y', nargs='*')] + failures = [] + successes = [ + ('', NS(x=None, y=[])), + ('-x', NS(x=[], y=[])), + ('-x a', NS(x=['a'], y=[])), + ('-x a -- b', NS(x=['a'], y=['b'])), + ('a', NS(x=None, y=['a'])), + ('a -x', NS(x=[], y=['a'])), + ('a -x b', NS(x=['b'], y=['a'])), + ] + + +class TestNargsRemainder(ParserTestCase): + """Tests specifying a positional with nargs=REMAINDER""" + + argument_signatures = [Sig('x'), Sig('y', nargs='...'), Sig('-z')] + failures = ['', '-z', '-z Z'] + successes = [ + ('X', NS(x='X', y=[], z=None)), + ('-z Z X', NS(x='X', y=[], z='Z')), + ('X A B -z Z', NS(x='X', y=['A', 'B', '-z', 'Z'], z=None)), + ('X Y --foo', NS(x='X', y=['Y', '--foo'], z=None)), + ] + + +class TestOptionLike(ParserTestCase): + """Tests options that may or may not be arguments""" + + argument_signatures = [ + Sig('-x', type=float), + Sig('-3', type=float, dest='y'), + Sig('z', nargs='*'), + ] + failures = ['-x', '-y2.5', '-xa', '-x -a', + '-x -3', '-x -3.5', '-3 -3.5', + '-x -2.5', '-x -2.5 a', '-3 -.5', + 'a x -1', '-x -1 a', '-3 -1 a'] + successes = [ + ('', NS(x=None, y=None, z=[])), + ('-x 2.5', NS(x=2.5, y=None, z=[])), + ('-x 2.5 a', NS(x=2.5, y=None, z=['a'])), + ('-3.5', NS(x=None, y=0.5, z=[])), + ('-3-.5', NS(x=None, y=-0.5, z=[])), + ('-3 .5', NS(x=None, y=0.5, z=[])), + ('a -3.5', NS(x=None, y=0.5, z=['a'])), + ('a', NS(x=None, y=None, z=['a'])), + ('a -x 1', NS(x=1.0, y=None, z=['a'])), + ('-x 1 a', NS(x=1.0, y=None, z=['a'])), + ('-3 1 a', NS(x=None, y=1.0, z=['a'])), + ] + + +class TestDefaultSuppress(ParserTestCase): + """Test actions with suppressed defaults""" + + argument_signatures = [ + Sig('foo', nargs='?', default=argparse.SUPPRESS), + Sig('bar', nargs='*', default=argparse.SUPPRESS), + Sig('--baz', action='store_true', default=argparse.SUPPRESS), + ] + failures = ['-x'] + successes = [ + ('', NS()), + ('a', NS(foo='a')), + ('a b', NS(foo='a', bar=['b'])), + ('--baz', NS(baz=True)), + ('a --baz', NS(foo='a', baz=True)), + ('--baz a b', NS(foo='a', bar=['b'], baz=True)), + ] + + +class TestParserDefaultSuppress(ParserTestCase): + """Test actions with a parser-level default of SUPPRESS""" + + parser_signature = Sig(argument_default=argparse.SUPPRESS) + argument_signatures = [ + Sig('foo', nargs='?'), + Sig('bar', nargs='*'), + Sig('--baz', action='store_true'), + ] + failures = ['-x'] + successes = [ + ('', NS()), + ('a', NS(foo='a')), + ('a b', NS(foo='a', bar=['b'])), + ('--baz', NS(baz=True)), + ('a --baz', NS(foo='a', baz=True)), + ('--baz a b', NS(foo='a', bar=['b'], baz=True)), + ] + + +class TestParserDefault42(ParserTestCase): + """Test actions with a parser-level default of 42""" + + parser_signature = Sig(argument_default=42) + argument_signatures = [ + Sig('--version', action='version', version='1.0'), + Sig('foo', nargs='?'), + Sig('bar', nargs='*'), + Sig('--baz', action='store_true'), + ] + failures = ['-x'] + successes = [ + ('', NS(foo=42, bar=42, baz=42, version=42)), + ('a', NS(foo='a', bar=42, baz=42, version=42)), + ('a b', NS(foo='a', bar=['b'], baz=42, version=42)), + ('--baz', NS(foo=42, bar=42, baz=True, version=42)), + ('a --baz', NS(foo='a', bar=42, baz=True, version=42)), + ('--baz a b', NS(foo='a', bar=['b'], baz=True, version=42)), + ] + + +class TestArgumentsFromFile(TempDirMixin, ParserTestCase): + """Test reading arguments from a file""" + + def setUp(self): + super(TestArgumentsFromFile, self).setUp() + file_texts = [ + ('hello', 'hello world!\n'), + ('recursive', '-a\n' + 'A\n' + '@hello'), + ('invalid', '@no-such-path\n'), + ] + for path, text in file_texts: + with open(path, 'w') as file: + file.write(text) + + parser_signature = Sig(fromfile_prefix_chars='@') + argument_signatures = [ + Sig('-a'), + Sig('x'), + Sig('y', nargs='+'), + ] + failures = ['', '-b', 'X', '@invalid', '@missing'] + successes = [ + ('X Y', NS(a=None, x='X', y=['Y'])), + ('X -a A Y Z', NS(a='A', x='X', y=['Y', 'Z'])), + ('@hello X', NS(a=None, x='hello world!', y=['X'])), + ('X @hello', NS(a=None, x='X', y=['hello world!'])), + ('-a B @recursive Y Z', NS(a='A', x='hello world!', y=['Y', 'Z'])), + ('X @recursive Z -a B', NS(a='B', x='X', y=['hello world!', 'Z'])), + (["-a", "", "X", "Y"], NS(a='', x='X', y=['Y'])), + ] + + +class TestArgumentsFromFileConverter(TempDirMixin, ParserTestCase): + """Test reading arguments from a file""" + + def setUp(self): + super(TestArgumentsFromFileConverter, self).setUp() + file_texts = [ + ('hello', 'hello world!\n'), + ] + for path, text in file_texts: + with open(path, 'w') as file: + file.write(text) + + class FromFileConverterArgumentParser(ErrorRaisingArgumentParser): + + def convert_arg_line_to_args(self, arg_line): + for arg in arg_line.split(): + if not arg.strip(): + continue + yield arg + parser_class = FromFileConverterArgumentParser + parser_signature = Sig(fromfile_prefix_chars='@') + argument_signatures = [ + Sig('y', nargs='+'), + ] + failures = [] + successes = [ + ('@hello X', NS(y=['hello', 'world!', 'X'])), + ] + + +# ===================== +# Type conversion tests +# ===================== + +class TestFileTypeRepr(TestCase): + + def test_r(self): + type = argparse.FileType('r') + self.assertEqual("FileType('r')", repr(type)) + + def test_wb_1(self): + type = argparse.FileType('wb', 1) + self.assertEqual("FileType('wb', 1)", repr(type)) + + def test_r_latin(self): + type = argparse.FileType('r', encoding='latin_1') + self.assertEqual("FileType('r', encoding='latin_1')", repr(type)) + + def test_w_big5_ignore(self): + type = argparse.FileType('w', encoding='big5', errors='ignore') + self.assertEqual("FileType('w', encoding='big5', errors='ignore')", + repr(type)) + + def test_r_1_replace(self): + type = argparse.FileType('r', 1, errors='replace') + self.assertEqual("FileType('r', 1, errors='replace')", repr(type)) + +class StdStreamComparer: + def __init__(self, attr): + self.attr = attr + + def __eq__(self, other): + return other == getattr(sys, self.attr) + +eq_stdin = StdStreamComparer('stdin') +eq_stdout = StdStreamComparer('stdout') +eq_stderr = StdStreamComparer('stderr') + +class RFile(object): + seen = {} + + def __init__(self, name): + self.name = name + + def __eq__(self, other): + if other in self.seen: + text = self.seen[other] + else: + text = self.seen[other] = other.read() + other.close() + if not isinstance(text, str): + text = text.decode('ascii') + return self.name == other.name == text + + +class TestFileTypeR(TempDirMixin, ParserTestCase): + """Test the FileType option/argument type for reading files""" + + def setUp(self): + super(TestFileTypeR, self).setUp() + for file_name in ['foo', 'bar']: + with open(os.path.join(self.temp_dir, file_name), 'w') as file: + file.write(file_name) + self.create_readonly_file('readonly') + + argument_signatures = [ + Sig('-x', type=argparse.FileType()), + Sig('spam', type=argparse.FileType('r')), + ] + failures = ['-x', '', 'non-existent-file.txt'] + successes = [ + ('foo', NS(x=None, spam=RFile('foo'))), + ('-x foo bar', NS(x=RFile('foo'), spam=RFile('bar'))), + ('bar -x foo', NS(x=RFile('foo'), spam=RFile('bar'))), + ('-x - -', NS(x=eq_stdin, spam=eq_stdin)), + ('readonly', NS(x=None, spam=RFile('readonly'))), + ] + +class TestFileTypeDefaults(TempDirMixin, ParserTestCase): + """Test that a file is not created unless the default is needed""" + def setUp(self): + super(TestFileTypeDefaults, self).setUp() + file = open(os.path.join(self.temp_dir, 'good'), 'w') + file.write('good') + file.close() + + argument_signatures = [ + Sig('-c', type=argparse.FileType('r'), default='no-file.txt'), + ] + # should provoke no such file error + failures = [''] + # should not provoke error because default file is created + successes = [('-c good', NS(c=RFile('good')))] + + +class TestFileTypeRB(TempDirMixin, ParserTestCase): + """Test the FileType option/argument type for reading files""" + + def setUp(self): + super(TestFileTypeRB, self).setUp() + for file_name in ['foo', 'bar']: + with open(os.path.join(self.temp_dir, file_name), 'w') as file: + file.write(file_name) + + argument_signatures = [ + Sig('-x', type=argparse.FileType('rb')), + Sig('spam', type=argparse.FileType('rb')), + ] + failures = ['-x', ''] + successes = [ + ('foo', NS(x=None, spam=RFile('foo'))), + ('-x foo bar', NS(x=RFile('foo'), spam=RFile('bar'))), + ('bar -x foo', NS(x=RFile('foo'), spam=RFile('bar'))), + ('-x - -', NS(x=eq_stdin, spam=eq_stdin)), + ] + + +class WFile(object): + seen = set() + + def __init__(self, name): + self.name = name + + def __eq__(self, other): + if other not in self.seen: + text = 'Check that file is writable.' + if 'b' in other.mode: + text = text.encode('ascii') + other.write(text) + other.close() + self.seen.add(other) + return self.name == other.name + + +@unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0, + "non-root user required") +class TestFileTypeW(TempDirMixin, ParserTestCase): + """Test the FileType option/argument type for writing files""" + + def setUp(self): + super(TestFileTypeW, self).setUp() + self.create_readonly_file('readonly') + + argument_signatures = [ + Sig('-x', type=argparse.FileType('w')), + Sig('spam', type=argparse.FileType('w')), + ] + failures = ['-x', '', 'readonly'] + successes = [ + ('foo', NS(x=None, spam=WFile('foo'))), + ('-x foo bar', NS(x=WFile('foo'), spam=WFile('bar'))), + ('bar -x foo', NS(x=WFile('foo'), spam=WFile('bar'))), + ('-x - -', NS(x=eq_stdout, spam=eq_stdout)), + ] + + +class TestFileTypeWB(TempDirMixin, ParserTestCase): + + argument_signatures = [ + Sig('-x', type=argparse.FileType('wb')), + Sig('spam', type=argparse.FileType('wb')), + ] + failures = ['-x', ''] + successes = [ + ('foo', NS(x=None, spam=WFile('foo'))), + ('-x foo bar', NS(x=WFile('foo'), spam=WFile('bar'))), + ('bar -x foo', NS(x=WFile('foo'), spam=WFile('bar'))), + ('-x - -', NS(x=eq_stdout, spam=eq_stdout)), + ] + + +class TestFileTypeOpenArgs(TestCase): + """Test that open (the builtin) is correctly called""" + + def test_open_args(self): + FT = argparse.FileType + cases = [ + (FT('rb'), ('rb', -1, None, None)), + (FT('w', 1), ('w', 1, None, None)), + (FT('w', errors='replace'), ('w', -1, None, 'replace')), + (FT('wb', encoding='big5'), ('wb', -1, 'big5', None)), + (FT('w', 0, 'l1', 'strict'), ('w', 0, 'l1', 'strict')), + ] + with mock.patch('builtins.open') as m: + for type, args in cases: + type('foo') + m.assert_called_with('foo', *args) + + +class TestTypeCallable(ParserTestCase): + """Test some callables as option/argument types""" + + argument_signatures = [ + Sig('--eggs', type=complex), + Sig('spam', type=float), + ] + failures = ['a', '42j', '--eggs a', '--eggs 2i'] + successes = [ + ('--eggs=42 42', NS(eggs=42, spam=42.0)), + ('--eggs 2j -- -1.5', NS(eggs=2j, spam=-1.5)), + ('1024.675', NS(eggs=None, spam=1024.675)), + ] + + +class TestTypeUserDefined(ParserTestCase): + """Test a user-defined option/argument type""" + + class MyType(TestCase): + + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return (type(self), self.value) == (type(other), other.value) + + argument_signatures = [ + Sig('-x', type=MyType), + Sig('spam', type=MyType), + ] + failures = [] + successes = [ + ('a -x b', NS(x=MyType('b'), spam=MyType('a'))), + ('-xf g', NS(x=MyType('f'), spam=MyType('g'))), + ] + + +class TestTypeClassicClass(ParserTestCase): + """Test a classic class type""" + + class C: + + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return (type(self), self.value) == (type(other), other.value) + + argument_signatures = [ + Sig('-x', type=C), + Sig('spam', type=C), + ] + failures = [] + successes = [ + ('a -x b', NS(x=C('b'), spam=C('a'))), + ('-xf g', NS(x=C('f'), spam=C('g'))), + ] + + +class TestTypeRegistration(TestCase): + """Test a user-defined type by registering it""" + + def test(self): + + def get_my_type(string): + return 'my_type{%s}' % string + + parser = argparse.ArgumentParser() + parser.register('type', 'my_type', get_my_type) + parser.add_argument('-x', type='my_type') + parser.add_argument('y', type='my_type') + + self.assertEqual(parser.parse_args('1'.split()), + NS(x=None, y='my_type{1}')) + self.assertEqual(parser.parse_args('-x 1 42'.split()), + NS(x='my_type{1}', y='my_type{42}')) + + +# ============ +# Action tests +# ============ + +class TestActionUserDefined(ParserTestCase): + """Test a user-defined option/argument action""" + + class OptionalAction(argparse.Action): + + def __call__(self, parser, namespace, value, option_string=None): + try: + # check destination and option string + assert self.dest == 'spam', 'dest: %s' % self.dest + assert option_string == '-s', 'flag: %s' % option_string + # when option is before argument, badger=2, and when + # option is after argument, badger= + expected_ns = NS(spam=0.25) + if value in [0.125, 0.625]: + expected_ns.badger = 2 + elif value in [2.0]: + expected_ns.badger = 84 + else: + raise AssertionError('value: %s' % value) + assert expected_ns == namespace, ('expected %s, got %s' % + (expected_ns, namespace)) + except AssertionError: + e = sys.exc_info()[1] + raise ArgumentParserError('opt_action failed: %s' % e) + setattr(namespace, 'spam', value) + + class PositionalAction(argparse.Action): + + def __call__(self, parser, namespace, value, option_string=None): + try: + assert option_string is None, ('option_string: %s' % + option_string) + # check destination + assert self.dest == 'badger', 'dest: %s' % self.dest + # when argument is before option, spam=0.25, and when + # option is after argument, spam= + expected_ns = NS(badger=2) + if value in [42, 84]: + expected_ns.spam = 0.25 + elif value in [1]: + expected_ns.spam = 0.625 + elif value in [2]: + expected_ns.spam = 0.125 + else: + raise AssertionError('value: %s' % value) + assert expected_ns == namespace, ('expected %s, got %s' % + (expected_ns, namespace)) + except AssertionError: + e = sys.exc_info()[1] + raise ArgumentParserError('arg_action failed: %s' % e) + setattr(namespace, 'badger', value) + + argument_signatures = [ + Sig('-s', dest='spam', action=OptionalAction, + type=float, default=0.25), + Sig('badger', action=PositionalAction, + type=int, nargs='?', default=2), + ] + failures = [] + successes = [ + ('-s0.125', NS(spam=0.125, badger=2)), + ('42', NS(spam=0.25, badger=42)), + ('-s 0.625 1', NS(spam=0.625, badger=1)), + ('84 -s2', NS(spam=2.0, badger=84)), + ] + + +class TestActionRegistration(TestCase): + """Test a user-defined action supplied by registering it""" + + class MyAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, 'foo[%s]' % values) + + def test(self): + + parser = argparse.ArgumentParser() + parser.register('action', 'my_action', self.MyAction) + parser.add_argument('badger', action='my_action') + + self.assertEqual(parser.parse_args(['1']), NS(badger='foo[1]')) + self.assertEqual(parser.parse_args(['42']), NS(badger='foo[42]')) + + +# ================ +# Subparsers tests +# ================ + +class TestAddSubparsers(TestCase): + """Test the add_subparsers method""" + + def assertArgumentParserError(self, *args, **kwargs): + self.assertRaises(ArgumentParserError, *args, **kwargs) + + def _get_parser(self, subparser_help=False, prefix_chars=None, + aliases=False): + # create a parser with a subparsers argument + if prefix_chars: + parser = ErrorRaisingArgumentParser( + prog='PROG', description='main description', prefix_chars=prefix_chars) + parser.add_argument( + prefix_chars[0] * 2 + 'foo', action='store_true', help='foo help') + else: + parser = ErrorRaisingArgumentParser( + prog='PROG', description='main description') + parser.add_argument( + '--foo', action='store_true', help='foo help') + parser.add_argument( + 'bar', type=float, help='bar help') + + # check that only one subparsers argument can be added + subparsers_kwargs = {'required': False} + if aliases: + subparsers_kwargs['metavar'] = 'COMMAND' + subparsers_kwargs['title'] = 'commands' + else: + subparsers_kwargs['help'] = 'command help' + subparsers = parser.add_subparsers(**subparsers_kwargs) + self.assertArgumentParserError(parser.add_subparsers) + + # add first sub-parser + parser1_kwargs = dict(description='1 description') + if subparser_help: + parser1_kwargs['help'] = '1 help' + if aliases: + parser1_kwargs['aliases'] = ['1alias1', '1alias2'] + parser1 = subparsers.add_parser('1', **parser1_kwargs) + parser1.add_argument('-w', type=int, help='w help') + parser1.add_argument('x', choices='abc', help='x help') + + # add second sub-parser + parser2_kwargs = dict(description='2 description') + if subparser_help: + parser2_kwargs['help'] = '2 help' + parser2 = subparsers.add_parser('2', **parser2_kwargs) + parser2.add_argument('-y', choices='123', help='y help') + parser2.add_argument('z', type=complex, nargs='*', help='z help') + + # add third sub-parser + parser3_kwargs = dict(description='3 description') + if subparser_help: + parser3_kwargs['help'] = '3 help' + parser3 = subparsers.add_parser('3', **parser3_kwargs) + parser3.add_argument('t', type=int, help='t help') + parser3.add_argument('u', nargs='...', help='u help') + + # return the main parser + return parser + + def setUp(self): + super().setUp() + self.parser = self._get_parser() + self.command_help_parser = self._get_parser(subparser_help=True) + + def test_parse_args_failures(self): + # check some failure cases: + for args_str in ['', 'a', 'a a', '0.5 a', '0.5 1', + '0.5 1 -y', '0.5 2 -w']: + args = args_str.split() + self.assertArgumentParserError(self.parser.parse_args, args) + + def test_parse_args(self): + # check some non-failure cases: + self.assertEqual( + self.parser.parse_args('0.5 1 b -w 7'.split()), + NS(foo=False, bar=0.5, w=7, x='b'), + ) + self.assertEqual( + self.parser.parse_args('0.25 --foo 2 -y 2 3j -- -1j'.split()), + NS(foo=True, bar=0.25, y='2', z=[3j, -1j]), + ) + self.assertEqual( + self.parser.parse_args('--foo 0.125 1 c'.split()), + NS(foo=True, bar=0.125, w=None, x='c'), + ) + self.assertEqual( + self.parser.parse_args('-1.5 3 11 -- a --foo 7 -- b'.split()), + NS(foo=False, bar=-1.5, t=11, u=['a', '--foo', '7', '--', 'b']), + ) + + def test_parse_known_args(self): + self.assertEqual( + self.parser.parse_known_args('0.5 1 b -w 7'.split()), + (NS(foo=False, bar=0.5, w=7, x='b'), []), + ) + self.assertEqual( + self.parser.parse_known_args('0.5 -p 1 b -w 7'.split()), + (NS(foo=False, bar=0.5, w=7, x='b'), ['-p']), + ) + self.assertEqual( + self.parser.parse_known_args('0.5 1 b -w 7 -p'.split()), + (NS(foo=False, bar=0.5, w=7, x='b'), ['-p']), + ) + self.assertEqual( + self.parser.parse_known_args('0.5 1 b -q -rs -w 7'.split()), + (NS(foo=False, bar=0.5, w=7, x='b'), ['-q', '-rs']), + ) + self.assertEqual( + self.parser.parse_known_args('0.5 -W 1 b -X Y -w 7 Z'.split()), + (NS(foo=False, bar=0.5, w=7, x='b'), ['-W', '-X', 'Y', 'Z']), + ) + + def test_dest(self): + parser = ErrorRaisingArgumentParser() + parser.add_argument('--foo', action='store_true') + subparsers = parser.add_subparsers(dest='bar') + parser1 = subparsers.add_parser('1') + parser1.add_argument('baz') + self.assertEqual(NS(foo=False, bar='1', baz='2'), + parser.parse_args('1 2'.split())) + + def _test_required_subparsers(self, parser): + # Should parse the sub command + ret = parser.parse_args(['run']) + self.assertEqual(ret.command, 'run') + + # Error when the command is missing + self.assertArgumentParserError(parser.parse_args, ()) + + def test_required_subparsers_via_attribute(self): + parser = ErrorRaisingArgumentParser() + subparsers = parser.add_subparsers(dest='command') + subparsers.required = True + subparsers.add_parser('run') + self._test_required_subparsers(parser) + + def test_required_subparsers_via_kwarg(self): + parser = ErrorRaisingArgumentParser() + subparsers = parser.add_subparsers(dest='command', required=True) + subparsers.add_parser('run') + self._test_required_subparsers(parser) + + def test_required_subparsers_default(self): + parser = ErrorRaisingArgumentParser() + subparsers = parser.add_subparsers(dest='command') + subparsers.add_parser('run') + # No error here + ret = parser.parse_args(()) + self.assertIsNone(ret.command) + + def test_optional_subparsers(self): + parser = ErrorRaisingArgumentParser() + subparsers = parser.add_subparsers(dest='command', required=False) + subparsers.add_parser('run') + # No error here + ret = parser.parse_args(()) + self.assertIsNone(ret.command) + + def test_help(self): + self.assertEqual(self.parser.format_usage(), + 'usage: PROG [-h] [--foo] bar {1,2,3} ...\n') + self.assertEqual(self.parser.format_help(), textwrap.dedent('''\ + usage: PROG [-h] [--foo] bar {1,2,3} ... + + main description + + positional arguments: + bar bar help + {1,2,3} command help + + optional arguments: + -h, --help show this help message and exit + --foo foo help + ''')) + + def test_help_extra_prefix_chars(self): + # Make sure - is still used for help if it is a non-first prefix char + parser = self._get_parser(prefix_chars='+:-') + self.assertEqual(parser.format_usage(), + 'usage: PROG [-h] [++foo] bar {1,2,3} ...\n') + self.assertEqual(parser.format_help(), textwrap.dedent('''\ + usage: PROG [-h] [++foo] bar {1,2,3} ... + + main description + + positional arguments: + bar bar help + {1,2,3} command help + + optional arguments: + -h, --help show this help message and exit + ++foo foo help + ''')) + + def test_help_non_breaking_spaces(self): + parser = ErrorRaisingArgumentParser( + prog='PROG', description='main description') + parser.add_argument( + "--non-breaking", action='store_false', + help='help message containing non-breaking spaces shall not ' + 'wrap\N{NO-BREAK SPACE}at non-breaking spaces') + self.assertEqual(parser.format_help(), textwrap.dedent('''\ + usage: PROG [-h] [--non-breaking] + + main description + + optional arguments: + -h, --help show this help message and exit + --non-breaking help message containing non-breaking spaces shall not + wrap\N{NO-BREAK SPACE}at non-breaking spaces + ''')) + + def test_help_alternate_prefix_chars(self): + parser = self._get_parser(prefix_chars='+:/') + self.assertEqual(parser.format_usage(), + 'usage: PROG [+h] [++foo] bar {1,2,3} ...\n') + self.assertEqual(parser.format_help(), textwrap.dedent('''\ + usage: PROG [+h] [++foo] bar {1,2,3} ... + + main description + + positional arguments: + bar bar help + {1,2,3} command help + + optional arguments: + +h, ++help show this help message and exit + ++foo foo help + ''')) + + def test_parser_command_help(self): + self.assertEqual(self.command_help_parser.format_usage(), + 'usage: PROG [-h] [--foo] bar {1,2,3} ...\n') + self.assertEqual(self.command_help_parser.format_help(), + textwrap.dedent('''\ + usage: PROG [-h] [--foo] bar {1,2,3} ... + + main description + + positional arguments: + bar bar help + {1,2,3} command help + 1 1 help + 2 2 help + 3 3 help + + optional arguments: + -h, --help show this help message and exit + --foo foo help + ''')) + + def test_subparser_title_help(self): + parser = ErrorRaisingArgumentParser(prog='PROG', + description='main description') + parser.add_argument('--foo', action='store_true', help='foo help') + parser.add_argument('bar', help='bar help') + subparsers = parser.add_subparsers(title='subcommands', + description='command help', + help='additional text') + parser1 = subparsers.add_parser('1') + parser2 = subparsers.add_parser('2') + self.assertEqual(parser.format_usage(), + 'usage: PROG [-h] [--foo] bar {1,2} ...\n') + self.assertEqual(parser.format_help(), textwrap.dedent('''\ + usage: PROG [-h] [--foo] bar {1,2} ... + + main description + + positional arguments: + bar bar help + + optional arguments: + -h, --help show this help message and exit + --foo foo help + + subcommands: + command help + + {1,2} additional text + ''')) + + def _test_subparser_help(self, args_str, expected_help): + with self.assertRaises(ArgumentParserError) as cm: + self.parser.parse_args(args_str.split()) + self.assertEqual(expected_help, cm.exception.stdout) + + def test_subparser1_help(self): + self._test_subparser_help('5.0 1 -h', textwrap.dedent('''\ + usage: PROG bar 1 [-h] [-w W] {a,b,c} + + 1 description + + positional arguments: + {a,b,c} x help + + optional arguments: + -h, --help show this help message and exit + -w W w help + ''')) + + def test_subparser2_help(self): + self._test_subparser_help('5.0 2 -h', textwrap.dedent('''\ + usage: PROG bar 2 [-h] [-y {1,2,3}] [z [z ...]] + + 2 description + + positional arguments: + z z help + + optional arguments: + -h, --help show this help message and exit + -y {1,2,3} y help + ''')) + + def test_alias_invocation(self): + parser = self._get_parser(aliases=True) + self.assertEqual( + parser.parse_known_args('0.5 1alias1 b'.split()), + (NS(foo=False, bar=0.5, w=None, x='b'), []), + ) + self.assertEqual( + parser.parse_known_args('0.5 1alias2 b'.split()), + (NS(foo=False, bar=0.5, w=None, x='b'), []), + ) + + def test_error_alias_invocation(self): + parser = self._get_parser(aliases=True) + self.assertArgumentParserError(parser.parse_args, + '0.5 1alias3 b'.split()) + + def test_alias_help(self): + parser = self._get_parser(aliases=True, subparser_help=True) + self.maxDiff = None + self.assertEqual(parser.format_help(), textwrap.dedent("""\ + usage: PROG [-h] [--foo] bar COMMAND ... + + main description + + positional arguments: + bar bar help + + optional arguments: + -h, --help show this help message and exit + --foo foo help + + commands: + COMMAND + 1 (1alias1, 1alias2) + 1 help + 2 2 help + 3 3 help + """)) + +# ============ +# Groups tests +# ============ + +class TestPositionalsGroups(TestCase): + """Tests that order of group positionals matches construction order""" + + def test_nongroup_first(self): + parser = ErrorRaisingArgumentParser() + parser.add_argument('foo') + group = parser.add_argument_group('g') + group.add_argument('bar') + parser.add_argument('baz') + expected = NS(foo='1', bar='2', baz='3') + result = parser.parse_args('1 2 3'.split()) + self.assertEqual(expected, result) + + def test_group_first(self): + parser = ErrorRaisingArgumentParser() + group = parser.add_argument_group('xxx') + group.add_argument('foo') + parser.add_argument('bar') + parser.add_argument('baz') + expected = NS(foo='1', bar='2', baz='3') + result = parser.parse_args('1 2 3'.split()) + self.assertEqual(expected, result) + + def test_interleaved_groups(self): + parser = ErrorRaisingArgumentParser() + group = parser.add_argument_group('xxx') + parser.add_argument('foo') + group.add_argument('bar') + parser.add_argument('baz') + group = parser.add_argument_group('yyy') + group.add_argument('frell') + expected = NS(foo='1', bar='2', baz='3', frell='4') + result = parser.parse_args('1 2 3 4'.split()) + self.assertEqual(expected, result) + +# =================== +# Parent parser tests +# =================== + +class TestParentParsers(TestCase): + """Tests that parsers can be created with parent parsers""" + + def assertArgumentParserError(self, *args, **kwargs): + self.assertRaises(ArgumentParserError, *args, **kwargs) + + def setUp(self): + super().setUp() + self.wxyz_parent = ErrorRaisingArgumentParser(add_help=False) + self.wxyz_parent.add_argument('--w') + x_group = self.wxyz_parent.add_argument_group('x') + x_group.add_argument('-y') + self.wxyz_parent.add_argument('z') + + self.abcd_parent = ErrorRaisingArgumentParser(add_help=False) + self.abcd_parent.add_argument('a') + self.abcd_parent.add_argument('-b') + c_group = self.abcd_parent.add_argument_group('c') + c_group.add_argument('--d') + + self.w_parent = ErrorRaisingArgumentParser(add_help=False) + self.w_parent.add_argument('--w') + + self.z_parent = ErrorRaisingArgumentParser(add_help=False) + self.z_parent.add_argument('z') + + # parents with mutually exclusive groups + self.ab_mutex_parent = ErrorRaisingArgumentParser(add_help=False) + group = self.ab_mutex_parent.add_mutually_exclusive_group() + group.add_argument('-a', action='store_true') + group.add_argument('-b', action='store_true') + + self.main_program = os.path.basename(sys.argv[0]) + + def test_single_parent(self): + parser = ErrorRaisingArgumentParser(parents=[self.wxyz_parent]) + self.assertEqual(parser.parse_args('-y 1 2 --w 3'.split()), + NS(w='3', y='1', z='2')) + + def test_single_parent_mutex(self): + self._test_mutex_ab(self.ab_mutex_parent.parse_args) + parser = ErrorRaisingArgumentParser(parents=[self.ab_mutex_parent]) + self._test_mutex_ab(parser.parse_args) + + def test_single_granparent_mutex(self): + parents = [self.ab_mutex_parent] + parser = ErrorRaisingArgumentParser(add_help=False, parents=parents) + parser = ErrorRaisingArgumentParser(parents=[parser]) + self._test_mutex_ab(parser.parse_args) + + def _test_mutex_ab(self, parse_args): + self.assertEqual(parse_args([]), NS(a=False, b=False)) + self.assertEqual(parse_args(['-a']), NS(a=True, b=False)) + self.assertEqual(parse_args(['-b']), NS(a=False, b=True)) + self.assertArgumentParserError(parse_args, ['-a', '-b']) + self.assertArgumentParserError(parse_args, ['-b', '-a']) + self.assertArgumentParserError(parse_args, ['-c']) + self.assertArgumentParserError(parse_args, ['-a', '-c']) + self.assertArgumentParserError(parse_args, ['-b', '-c']) + + def test_multiple_parents(self): + parents = [self.abcd_parent, self.wxyz_parent] + parser = ErrorRaisingArgumentParser(parents=parents) + self.assertEqual(parser.parse_args('--d 1 --w 2 3 4'.split()), + NS(a='3', b=None, d='1', w='2', y=None, z='4')) + + def test_multiple_parents_mutex(self): + parents = [self.ab_mutex_parent, self.wxyz_parent] + parser = ErrorRaisingArgumentParser(parents=parents) + self.assertEqual(parser.parse_args('-a --w 2 3'.split()), + NS(a=True, b=False, w='2', y=None, z='3')) + self.assertArgumentParserError( + parser.parse_args, '-a --w 2 3 -b'.split()) + self.assertArgumentParserError( + parser.parse_args, '-a -b --w 2 3'.split()) + + def test_conflicting_parents(self): + self.assertRaises( + argparse.ArgumentError, + argparse.ArgumentParser, + parents=[self.w_parent, self.wxyz_parent]) + + def test_conflicting_parents_mutex(self): + self.assertRaises( + argparse.ArgumentError, + argparse.ArgumentParser, + parents=[self.abcd_parent, self.ab_mutex_parent]) + + def test_same_argument_name_parents(self): + parents = [self.wxyz_parent, self.z_parent] + parser = ErrorRaisingArgumentParser(parents=parents) + self.assertEqual(parser.parse_args('1 2'.split()), + NS(w=None, y=None, z='2')) + + def test_subparser_parents(self): + parser = ErrorRaisingArgumentParser() + subparsers = parser.add_subparsers() + abcde_parser = subparsers.add_parser('bar', parents=[self.abcd_parent]) + abcde_parser.add_argument('e') + self.assertEqual(parser.parse_args('bar -b 1 --d 2 3 4'.split()), + NS(a='3', b='1', d='2', e='4')) + + def test_subparser_parents_mutex(self): + parser = ErrorRaisingArgumentParser() + subparsers = parser.add_subparsers() + parents = [self.ab_mutex_parent] + abc_parser = subparsers.add_parser('foo', parents=parents) + c_group = abc_parser.add_argument_group('c_group') + c_group.add_argument('c') + parents = [self.wxyz_parent, self.ab_mutex_parent] + wxyzabe_parser = subparsers.add_parser('bar', parents=parents) + wxyzabe_parser.add_argument('e') + self.assertEqual(parser.parse_args('foo -a 4'.split()), + NS(a=True, b=False, c='4')) + self.assertEqual(parser.parse_args('bar -b --w 2 3 4'.split()), + NS(a=False, b=True, w='2', y=None, z='3', e='4')) + self.assertArgumentParserError( + parser.parse_args, 'foo -a -b 4'.split()) + self.assertArgumentParserError( + parser.parse_args, 'bar -b -a 4'.split()) + + def test_parent_help(self): + parents = [self.abcd_parent, self.wxyz_parent] + parser = ErrorRaisingArgumentParser(parents=parents) + parser_help = parser.format_help() + progname = self.main_program + self.assertEqual(parser_help, textwrap.dedent('''\ + usage: {}{}[-h] [-b B] [--d D] [--w W] [-y Y] a z + + positional arguments: + a + z + + optional arguments: + -h, --help show this help message and exit + -b B + --w W + + c: + --d D + + x: + -y Y + '''.format(progname, ' ' if progname else '' ))) + + def test_groups_parents(self): + parent = ErrorRaisingArgumentParser(add_help=False) + g = parent.add_argument_group(title='g', description='gd') + g.add_argument('-w') + g.add_argument('-x') + m = parent.add_mutually_exclusive_group() + m.add_argument('-y') + m.add_argument('-z') + parser = ErrorRaisingArgumentParser(parents=[parent]) + + self.assertRaises(ArgumentParserError, parser.parse_args, + ['-y', 'Y', '-z', 'Z']) + + parser_help = parser.format_help() + progname = self.main_program + self.assertEqual(parser_help, textwrap.dedent('''\ + usage: {}{}[-h] [-w W] [-x X] [-y Y | -z Z] + + optional arguments: + -h, --help show this help message and exit + -y Y + -z Z + + g: + gd + + -w W + -x X + '''.format(progname, ' ' if progname else '' ))) + +# ============================== +# Mutually exclusive group tests +# ============================== + +class TestMutuallyExclusiveGroupErrors(TestCase): + + def test_invalid_add_argument_group(self): + parser = ErrorRaisingArgumentParser() + raises = self.assertRaises + raises(TypeError, parser.add_mutually_exclusive_group, title='foo') + + def test_invalid_add_argument(self): + parser = ErrorRaisingArgumentParser() + group = parser.add_mutually_exclusive_group() + add_argument = group.add_argument + raises = self.assertRaises + raises(ValueError, add_argument, '--foo', required=True) + raises(ValueError, add_argument, 'bar') + raises(ValueError, add_argument, 'bar', nargs='+') + raises(ValueError, add_argument, 'bar', nargs=1) + raises(ValueError, add_argument, 'bar', nargs=argparse.PARSER) + + def test_help(self): + parser = ErrorRaisingArgumentParser(prog='PROG') + group1 = parser.add_mutually_exclusive_group() + group1.add_argument('--foo', action='store_true') + group1.add_argument('--bar', action='store_false') + group2 = parser.add_mutually_exclusive_group() + group2.add_argument('--soup', action='store_true') + group2.add_argument('--nuts', action='store_false') + expected = '''\ + usage: PROG [-h] [--foo | --bar] [--soup | --nuts] + + optional arguments: + -h, --help show this help message and exit + --foo + --bar + --soup + --nuts + ''' + self.assertEqual(parser.format_help(), textwrap.dedent(expected)) + +class MEMixin(object): + + def test_failures_when_not_required(self): + parse_args = self.get_parser(required=False).parse_args + error = ArgumentParserError + for args_string in self.failures: + self.assertRaises(error, parse_args, args_string.split()) + + def test_failures_when_required(self): + parse_args = self.get_parser(required=True).parse_args + error = ArgumentParserError + for args_string in self.failures + ['']: + self.assertRaises(error, parse_args, args_string.split()) + + def test_successes_when_not_required(self): + parse_args = self.get_parser(required=False).parse_args + successes = self.successes + self.successes_when_not_required + for args_string, expected_ns in successes: + actual_ns = parse_args(args_string.split()) + self.assertEqual(actual_ns, expected_ns) + + def test_successes_when_required(self): + parse_args = self.get_parser(required=True).parse_args + for args_string, expected_ns in self.successes: + actual_ns = parse_args(args_string.split()) + self.assertEqual(actual_ns, expected_ns) + + def test_usage_when_not_required(self): + format_usage = self.get_parser(required=False).format_usage + expected_usage = self.usage_when_not_required + self.assertEqual(format_usage(), textwrap.dedent(expected_usage)) + + def test_usage_when_required(self): + format_usage = self.get_parser(required=True).format_usage + expected_usage = self.usage_when_required + self.assertEqual(format_usage(), textwrap.dedent(expected_usage)) + + def test_help_when_not_required(self): + format_help = self.get_parser(required=False).format_help + help = self.usage_when_not_required + self.help + self.assertEqual(format_help(), textwrap.dedent(help)) + + def test_help_when_required(self): + format_help = self.get_parser(required=True).format_help + help = self.usage_when_required + self.help + self.assertEqual(format_help(), textwrap.dedent(help)) + + +class TestMutuallyExclusiveSimple(MEMixin, TestCase): + + def get_parser(self, required=None): + parser = ErrorRaisingArgumentParser(prog='PROG') + group = parser.add_mutually_exclusive_group(required=required) + group.add_argument('--bar', help='bar help') + group.add_argument('--baz', nargs='?', const='Z', help='baz help') + return parser + + failures = ['--bar X --baz Y', '--bar X --baz'] + successes = [ + ('--bar X', NS(bar='X', baz=None)), + ('--bar X --bar Z', NS(bar='Z', baz=None)), + ('--baz Y', NS(bar=None, baz='Y')), + ('--baz', NS(bar=None, baz='Z')), + ] + successes_when_not_required = [ + ('', NS(bar=None, baz=None)), + ] + + usage_when_not_required = '''\ + usage: PROG [-h] [--bar BAR | --baz [BAZ]] + ''' + usage_when_required = '''\ + usage: PROG [-h] (--bar BAR | --baz [BAZ]) + ''' + help = '''\ + + optional arguments: + -h, --help show this help message and exit + --bar BAR bar help + --baz [BAZ] baz help + ''' + + +class TestMutuallyExclusiveLong(MEMixin, TestCase): + + def get_parser(self, required=None): + parser = ErrorRaisingArgumentParser(prog='PROG') + parser.add_argument('--abcde', help='abcde help') + parser.add_argument('--fghij', help='fghij help') + group = parser.add_mutually_exclusive_group(required=required) + group.add_argument('--klmno', help='klmno help') + group.add_argument('--pqrst', help='pqrst help') + return parser + + failures = ['--klmno X --pqrst Y'] + successes = [ + ('--klmno X', NS(abcde=None, fghij=None, klmno='X', pqrst=None)), + ('--abcde Y --klmno X', + NS(abcde='Y', fghij=None, klmno='X', pqrst=None)), + ('--pqrst X', NS(abcde=None, fghij=None, klmno=None, pqrst='X')), + ('--pqrst X --fghij Y', + NS(abcde=None, fghij='Y', klmno=None, pqrst='X')), + ] + successes_when_not_required = [ + ('', NS(abcde=None, fghij=None, klmno=None, pqrst=None)), + ] + + usage_when_not_required = '''\ + usage: PROG [-h] [--abcde ABCDE] [--fghij FGHIJ] + [--klmno KLMNO | --pqrst PQRST] + ''' + usage_when_required = '''\ + usage: PROG [-h] [--abcde ABCDE] [--fghij FGHIJ] + (--klmno KLMNO | --pqrst PQRST) + ''' + help = '''\ + + optional arguments: + -h, --help show this help message and exit + --abcde ABCDE abcde help + --fghij FGHIJ fghij help + --klmno KLMNO klmno help + --pqrst PQRST pqrst help + ''' + + +class TestMutuallyExclusiveFirstSuppressed(MEMixin, TestCase): + + def get_parser(self, required): + parser = ErrorRaisingArgumentParser(prog='PROG') + group = parser.add_mutually_exclusive_group(required=required) + group.add_argument('-x', help=argparse.SUPPRESS) + group.add_argument('-y', action='store_false', help='y help') + return parser + + failures = ['-x X -y'] + successes = [ + ('-x X', NS(x='X', y=True)), + ('-x X -x Y', NS(x='Y', y=True)), + ('-y', NS(x=None, y=False)), + ] + successes_when_not_required = [ + ('', NS(x=None, y=True)), + ] + + usage_when_not_required = '''\ + usage: PROG [-h] [-y] + ''' + usage_when_required = '''\ + usage: PROG [-h] -y + ''' + help = '''\ + + optional arguments: + -h, --help show this help message and exit + -y y help + ''' + + +class TestMutuallyExclusiveManySuppressed(MEMixin, TestCase): + + def get_parser(self, required): + parser = ErrorRaisingArgumentParser(prog='PROG') + group = parser.add_mutually_exclusive_group(required=required) + add = group.add_argument + add('--spam', action='store_true', help=argparse.SUPPRESS) + add('--badger', action='store_false', help=argparse.SUPPRESS) + add('--bladder', help=argparse.SUPPRESS) + return parser + + failures = [ + '--spam --badger', + '--badger --bladder B', + '--bladder B --spam', + ] + successes = [ + ('--spam', NS(spam=True, badger=True, bladder=None)), + ('--badger', NS(spam=False, badger=False, bladder=None)), + ('--bladder B', NS(spam=False, badger=True, bladder='B')), + ('--spam --spam', NS(spam=True, badger=True, bladder=None)), + ] + successes_when_not_required = [ + ('', NS(spam=False, badger=True, bladder=None)), + ] + + usage_when_required = usage_when_not_required = '''\ + usage: PROG [-h] + ''' + help = '''\ + + optional arguments: + -h, --help show this help message and exit + ''' + + +class TestMutuallyExclusiveOptionalAndPositional(MEMixin, TestCase): + + def get_parser(self, required): + parser = ErrorRaisingArgumentParser(prog='PROG') + group = parser.add_mutually_exclusive_group(required=required) + group.add_argument('--foo', action='store_true', help='FOO') + group.add_argument('--spam', help='SPAM') + group.add_argument('badger', nargs='*', default='X', help='BADGER') + return parser + + failures = [ + '--foo --spam S', + '--spam S X', + 'X --foo', + 'X Y Z --spam S', + '--foo X Y', + ] + successes = [ + ('--foo', NS(foo=True, spam=None, badger='X')), + ('--spam S', NS(foo=False, spam='S', badger='X')), + ('X', NS(foo=False, spam=None, badger=['X'])), + ('X Y Z', NS(foo=False, spam=None, badger=['X', 'Y', 'Z'])), + ] + successes_when_not_required = [ + ('', NS(foo=False, spam=None, badger='X')), + ] + + usage_when_not_required = '''\ + usage: PROG [-h] [--foo | --spam SPAM | badger [badger ...]] + ''' + usage_when_required = '''\ + usage: PROG [-h] (--foo | --spam SPAM | badger [badger ...]) + ''' + help = '''\ + + positional arguments: + badger BADGER + + optional arguments: + -h, --help show this help message and exit + --foo FOO + --spam SPAM SPAM + ''' + + +class TestMutuallyExclusiveOptionalsMixed(MEMixin, TestCase): + + def get_parser(self, required): + parser = ErrorRaisingArgumentParser(prog='PROG') + parser.add_argument('-x', action='store_true', help='x help') + group = parser.add_mutually_exclusive_group(required=required) + group.add_argument('-a', action='store_true', help='a help') + group.add_argument('-b', action='store_true', help='b help') + parser.add_argument('-y', action='store_true', help='y help') + group.add_argument('-c', action='store_true', help='c help') + return parser + + failures = ['-a -b', '-b -c', '-a -c', '-a -b -c'] + successes = [ + ('-a', NS(a=True, b=False, c=False, x=False, y=False)), + ('-b', NS(a=False, b=True, c=False, x=False, y=False)), + ('-c', NS(a=False, b=False, c=True, x=False, y=False)), + ('-a -x', NS(a=True, b=False, c=False, x=True, y=False)), + ('-y -b', NS(a=False, b=True, c=False, x=False, y=True)), + ('-x -y -c', NS(a=False, b=False, c=True, x=True, y=True)), + ] + successes_when_not_required = [ + ('', NS(a=False, b=False, c=False, x=False, y=False)), + ('-x', NS(a=False, b=False, c=False, x=True, y=False)), + ('-y', NS(a=False, b=False, c=False, x=False, y=True)), + ] + + usage_when_required = usage_when_not_required = '''\ + usage: PROG [-h] [-x] [-a] [-b] [-y] [-c] + ''' + help = '''\ + + optional arguments: + -h, --help show this help message and exit + -x x help + -a a help + -b b help + -y y help + -c c help + ''' + + +class TestMutuallyExclusiveInGroup(MEMixin, TestCase): + + def get_parser(self, required=None): + parser = ErrorRaisingArgumentParser(prog='PROG') + titled_group = parser.add_argument_group( + title='Titled group', description='Group description') + mutex_group = \ + titled_group.add_mutually_exclusive_group(required=required) + mutex_group.add_argument('--bar', help='bar help') + mutex_group.add_argument('--baz', help='baz help') + return parser + + failures = ['--bar X --baz Y', '--baz X --bar Y'] + successes = [ + ('--bar X', NS(bar='X', baz=None)), + ('--baz Y', NS(bar=None, baz='Y')), + ] + successes_when_not_required = [ + ('', NS(bar=None, baz=None)), + ] + + usage_when_not_required = '''\ + usage: PROG [-h] [--bar BAR | --baz BAZ] + ''' + usage_when_required = '''\ + usage: PROG [-h] (--bar BAR | --baz BAZ) + ''' + help = '''\ + + optional arguments: + -h, --help show this help message and exit + + Titled group: + Group description + + --bar BAR bar help + --baz BAZ baz help + ''' + + +class TestMutuallyExclusiveOptionalsAndPositionalsMixed(MEMixin, TestCase): + + def get_parser(self, required): + parser = ErrorRaisingArgumentParser(prog='PROG') + parser.add_argument('x', help='x help') + parser.add_argument('-y', action='store_true', help='y help') + group = parser.add_mutually_exclusive_group(required=required) + group.add_argument('a', nargs='?', help='a help') + group.add_argument('-b', action='store_true', help='b help') + group.add_argument('-c', action='store_true', help='c help') + return parser + + failures = ['X A -b', '-b -c', '-c X A'] + successes = [ + ('X A', NS(a='A', b=False, c=False, x='X', y=False)), + ('X -b', NS(a=None, b=True, c=False, x='X', y=False)), + ('X -c', NS(a=None, b=False, c=True, x='X', y=False)), + ('X A -y', NS(a='A', b=False, c=False, x='X', y=True)), + ('X -y -b', NS(a=None, b=True, c=False, x='X', y=True)), + ] + successes_when_not_required = [ + ('X', NS(a=None, b=False, c=False, x='X', y=False)), + ('X -y', NS(a=None, b=False, c=False, x='X', y=True)), + ] + + usage_when_required = usage_when_not_required = '''\ + usage: PROG [-h] [-y] [-b] [-c] x [a] + ''' + help = '''\ + + positional arguments: + x x help + a a help + + optional arguments: + -h, --help show this help message and exit + -y y help + -b b help + -c c help + ''' + +# ================================================= +# Mutually exclusive group in parent parser tests +# ================================================= + +class MEPBase(object): + + def get_parser(self, required=None): + parent = super(MEPBase, self).get_parser(required=required) + parser = ErrorRaisingArgumentParser( + prog=parent.prog, add_help=False, parents=[parent]) + return parser + + +class TestMutuallyExclusiveGroupErrorsParent( + MEPBase, TestMutuallyExclusiveGroupErrors): + pass + + +class TestMutuallyExclusiveSimpleParent( + MEPBase, TestMutuallyExclusiveSimple): + pass + + +class TestMutuallyExclusiveLongParent( + MEPBase, TestMutuallyExclusiveLong): + pass + + +class TestMutuallyExclusiveFirstSuppressedParent( + MEPBase, TestMutuallyExclusiveFirstSuppressed): + pass + + +class TestMutuallyExclusiveManySuppressedParent( + MEPBase, TestMutuallyExclusiveManySuppressed): + pass + + +class TestMutuallyExclusiveOptionalAndPositionalParent( + MEPBase, TestMutuallyExclusiveOptionalAndPositional): + pass + + +class TestMutuallyExclusiveOptionalsMixedParent( + MEPBase, TestMutuallyExclusiveOptionalsMixed): + pass + + +class TestMutuallyExclusiveOptionalsAndPositionalsMixedParent( + MEPBase, TestMutuallyExclusiveOptionalsAndPositionalsMixed): + pass + +# ================= +# Set default tests +# ================= + +class TestSetDefaults(TestCase): + + def test_set_defaults_no_args(self): + parser = ErrorRaisingArgumentParser() + parser.set_defaults(x='foo') + parser.set_defaults(y='bar', z=1) + self.assertEqual(NS(x='foo', y='bar', z=1), + parser.parse_args([])) + self.assertEqual(NS(x='foo', y='bar', z=1), + parser.parse_args([], NS())) + self.assertEqual(NS(x='baz', y='bar', z=1), + parser.parse_args([], NS(x='baz'))) + self.assertEqual(NS(x='baz', y='bar', z=2), + parser.parse_args([], NS(x='baz', z=2))) + + def test_set_defaults_with_args(self): + parser = ErrorRaisingArgumentParser() + parser.set_defaults(x='foo', y='bar') + parser.add_argument('-x', default='xfoox') + self.assertEqual(NS(x='xfoox', y='bar'), + parser.parse_args([])) + self.assertEqual(NS(x='xfoox', y='bar'), + parser.parse_args([], NS())) + self.assertEqual(NS(x='baz', y='bar'), + parser.parse_args([], NS(x='baz'))) + self.assertEqual(NS(x='1', y='bar'), + parser.parse_args('-x 1'.split())) + self.assertEqual(NS(x='1', y='bar'), + parser.parse_args('-x 1'.split(), NS())) + self.assertEqual(NS(x='1', y='bar'), + parser.parse_args('-x 1'.split(), NS(x='baz'))) + + def test_set_defaults_subparsers(self): + parser = ErrorRaisingArgumentParser() + parser.set_defaults(x='foo') + subparsers = parser.add_subparsers() + parser_a = subparsers.add_parser('a') + parser_a.set_defaults(y='bar') + self.assertEqual(NS(x='foo', y='bar'), + parser.parse_args('a'.split())) + + def test_set_defaults_parents(self): + parent = ErrorRaisingArgumentParser(add_help=False) + parent.set_defaults(x='foo') + parser = ErrorRaisingArgumentParser(parents=[parent]) + self.assertEqual(NS(x='foo'), parser.parse_args([])) + + def test_set_defaults_on_parent_and_subparser(self): + parser = argparse.ArgumentParser() + xparser = parser.add_subparsers().add_parser('X') + parser.set_defaults(foo=1) + xparser.set_defaults(foo=2) + self.assertEqual(NS(foo=2), parser.parse_args(['X'])) + + def test_set_defaults_same_as_add_argument(self): + parser = ErrorRaisingArgumentParser() + parser.set_defaults(w='W', x='X', y='Y', z='Z') + parser.add_argument('-w') + parser.add_argument('-x', default='XX') + parser.add_argument('y', nargs='?') + parser.add_argument('z', nargs='?', default='ZZ') + + # defaults set previously + self.assertEqual(NS(w='W', x='XX', y='Y', z='ZZ'), + parser.parse_args([])) + + # reset defaults + parser.set_defaults(w='WW', x='X', y='YY', z='Z') + self.assertEqual(NS(w='WW', x='X', y='YY', z='Z'), + parser.parse_args([])) + + def test_set_defaults_same_as_add_argument_group(self): + parser = ErrorRaisingArgumentParser() + parser.set_defaults(w='W', x='X', y='Y', z='Z') + group = parser.add_argument_group('foo') + group.add_argument('-w') + group.add_argument('-x', default='XX') + group.add_argument('y', nargs='?') + group.add_argument('z', nargs='?', default='ZZ') + + + # defaults set previously + self.assertEqual(NS(w='W', x='XX', y='Y', z='ZZ'), + parser.parse_args([])) + + # reset defaults + parser.set_defaults(w='WW', x='X', y='YY', z='Z') + self.assertEqual(NS(w='WW', x='X', y='YY', z='Z'), + parser.parse_args([])) + +# ================= +# Get default tests +# ================= + +class TestGetDefault(TestCase): + + def test_get_default(self): + parser = ErrorRaisingArgumentParser() + self.assertIsNone(parser.get_default("foo")) + self.assertIsNone(parser.get_default("bar")) + + parser.add_argument("--foo") + self.assertIsNone(parser.get_default("foo")) + self.assertIsNone(parser.get_default("bar")) + + parser.add_argument("--bar", type=int, default=42) + self.assertIsNone(parser.get_default("foo")) + self.assertEqual(42, parser.get_default("bar")) + + parser.set_defaults(foo="badger") + self.assertEqual("badger", parser.get_default("foo")) + self.assertEqual(42, parser.get_default("bar")) + +# ========================== +# Namespace 'contains' tests +# ========================== + +class TestNamespaceContainsSimple(TestCase): + + def test_empty(self): + ns = argparse.Namespace() + self.assertNotIn('', ns) + self.assertNotIn('x', ns) + + def test_non_empty(self): + ns = argparse.Namespace(x=1, y=2) + self.assertNotIn('', ns) + self.assertIn('x', ns) + self.assertIn('y', ns) + self.assertNotIn('xx', ns) + self.assertNotIn('z', ns) + +# ===================== +# Help formatting tests +# ===================== + +class TestHelpFormattingMetaclass(type): + + def __init__(cls, name, bases, bodydict): + if name == 'HelpTestCase': + return + + class AddTests(object): + + def __init__(self, test_class, func_suffix, std_name): + self.func_suffix = func_suffix + self.std_name = std_name + + for test_func in [self.test_format, + self.test_print, + self.test_print_file]: + test_name = '%s_%s' % (test_func.__name__, func_suffix) + + def test_wrapper(self, test_func=test_func): + test_func(self) + try: + test_wrapper.__name__ = test_name + except TypeError: + pass + setattr(test_class, test_name, test_wrapper) + + def _get_parser(self, tester): + parser = argparse.ArgumentParser( + *tester.parser_signature.args, + **tester.parser_signature.kwargs) + for argument_sig in getattr(tester, 'argument_signatures', []): + parser.add_argument(*argument_sig.args, + **argument_sig.kwargs) + group_sigs = getattr(tester, 'argument_group_signatures', []) + for group_sig, argument_sigs in group_sigs: + group = parser.add_argument_group(*group_sig.args, + **group_sig.kwargs) + for argument_sig in argument_sigs: + group.add_argument(*argument_sig.args, + **argument_sig.kwargs) + subparsers_sigs = getattr(tester, 'subparsers_signatures', []) + if subparsers_sigs: + subparsers = parser.add_subparsers() + for subparser_sig in subparsers_sigs: + subparsers.add_parser(*subparser_sig.args, + **subparser_sig.kwargs) + return parser + + def _test(self, tester, parser_text): + expected_text = getattr(tester, self.func_suffix) + expected_text = textwrap.dedent(expected_text) + tester.assertEqual(expected_text, parser_text) + + def test_format(self, tester): + parser = self._get_parser(tester) + format = getattr(parser, 'format_%s' % self.func_suffix) + self._test(tester, format()) + + def test_print(self, tester): + parser = self._get_parser(tester) + print_ = getattr(parser, 'print_%s' % self.func_suffix) + old_stream = getattr(sys, self.std_name) + setattr(sys, self.std_name, StdIOBuffer()) + try: + print_() + parser_text = getattr(sys, self.std_name).getvalue() + finally: + setattr(sys, self.std_name, old_stream) + self._test(tester, parser_text) + + def test_print_file(self, tester): + parser = self._get_parser(tester) + print_ = getattr(parser, 'print_%s' % self.func_suffix) + sfile = StdIOBuffer() + print_(sfile) + parser_text = sfile.getvalue() + self._test(tester, parser_text) + + # add tests for {format,print}_{usage,help} + for func_suffix, std_name in [('usage', 'stdout'), + ('help', 'stdout')]: + AddTests(cls, func_suffix, std_name) + +bases = TestCase, +HelpTestCase = TestHelpFormattingMetaclass('HelpTestCase', bases, {}) + + +class TestHelpBiggerOptionals(HelpTestCase): + """Make sure that argument help aligns when options are longer""" + + parser_signature = Sig(prog='PROG', description='DESCRIPTION', + epilog='EPILOG') + argument_signatures = [ + Sig('-v', '--version', action='version', version='0.1'), + Sig('-x', action='store_true', help='X HELP'), + Sig('--y', help='Y HELP'), + Sig('foo', help='FOO HELP'), + Sig('bar', help='BAR HELP'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-v] [-x] [--y Y] foo bar + ''' + help = usage + '''\ + + DESCRIPTION + + positional arguments: + foo FOO HELP + bar BAR HELP + + optional arguments: + -h, --help show this help message and exit + -v, --version show program's version number and exit + -x X HELP + --y Y Y HELP + + EPILOG + ''' + version = '''\ + 0.1 + ''' + +class TestShortColumns(HelpTestCase): + '''Test extremely small number of columns. + + TestCase prevents "COLUMNS" from being too small in the tests themselves, + but we don't want any exceptions thrown in such cases. Only ugly representation. + ''' + def setUp(self): + env = support.EnvironmentVarGuard() + env.set("COLUMNS", '15') + self.addCleanup(env.__exit__) + + parser_signature = TestHelpBiggerOptionals.parser_signature + argument_signatures = TestHelpBiggerOptionals.argument_signatures + argument_group_signatures = TestHelpBiggerOptionals.argument_group_signatures + usage = '''\ + usage: PROG + [-h] + [-v] + [-x] + [--y Y] + foo + bar + ''' + help = usage + '''\ + + DESCRIPTION + + positional arguments: + foo + FOO HELP + bar + BAR HELP + + optional arguments: + -h, --help + show this + help + message and + exit + -v, --version + show + program's + version + number and + exit + -x + X HELP + --y Y + Y HELP + + EPILOG + ''' + version = TestHelpBiggerOptionals.version + + +class TestHelpBiggerOptionalGroups(HelpTestCase): + """Make sure that argument help aligns when options are longer""" + + parser_signature = Sig(prog='PROG', description='DESCRIPTION', + epilog='EPILOG') + argument_signatures = [ + Sig('-v', '--version', action='version', version='0.1'), + Sig('-x', action='store_true', help='X HELP'), + Sig('--y', help='Y HELP'), + Sig('foo', help='FOO HELP'), + Sig('bar', help='BAR HELP'), + ] + argument_group_signatures = [ + (Sig('GROUP TITLE', description='GROUP DESCRIPTION'), [ + Sig('baz', help='BAZ HELP'), + Sig('-z', nargs='+', help='Z HELP')]), + ] + usage = '''\ + usage: PROG [-h] [-v] [-x] [--y Y] [-z Z [Z ...]] foo bar baz + ''' + help = usage + '''\ + + DESCRIPTION + + positional arguments: + foo FOO HELP + bar BAR HELP + + optional arguments: + -h, --help show this help message and exit + -v, --version show program's version number and exit + -x X HELP + --y Y Y HELP + + GROUP TITLE: + GROUP DESCRIPTION + + baz BAZ HELP + -z Z [Z ...] Z HELP + + EPILOG + ''' + version = '''\ + 0.1 + ''' + + +class TestHelpBiggerPositionals(HelpTestCase): + """Make sure that help aligns when arguments are longer""" + + parser_signature = Sig(usage='USAGE', description='DESCRIPTION') + argument_signatures = [ + Sig('-x', action='store_true', help='X HELP'), + Sig('--y', help='Y HELP'), + Sig('ekiekiekifekang', help='EKI HELP'), + Sig('bar', help='BAR HELP'), + ] + argument_group_signatures = [] + usage = '''\ + usage: USAGE + ''' + help = usage + '''\ + + DESCRIPTION + + positional arguments: + ekiekiekifekang EKI HELP + bar BAR HELP + + optional arguments: + -h, --help show this help message and exit + -x X HELP + --y Y Y HELP + ''' + + version = '' + + +class TestHelpReformatting(HelpTestCase): + """Make sure that text after short names starts on the first line""" + + parser_signature = Sig( + prog='PROG', + description=' oddly formatted\n' + 'description\n' + '\n' + 'that is so long that it should go onto multiple ' + 'lines when wrapped') + argument_signatures = [ + Sig('-x', metavar='XX', help='oddly\n' + ' formatted -x help'), + Sig('y', metavar='yyy', help='normal y help'), + ] + argument_group_signatures = [ + (Sig('title', description='\n' + ' oddly formatted group\n' + '\n' + 'description'), + [Sig('-a', action='store_true', + help=' oddly \n' + 'formatted -a help \n' + ' again, so long that it should be wrapped over ' + 'multiple lines')]), + ] + usage = '''\ + usage: PROG [-h] [-x XX] [-a] yyy + ''' + help = usage + '''\ + + oddly formatted description that is so long that it should go onto \ +multiple + lines when wrapped + + positional arguments: + yyy normal y help + + optional arguments: + -h, --help show this help message and exit + -x XX oddly formatted -x help + + title: + oddly formatted group description + + -a oddly formatted -a help again, so long that it should \ +be wrapped + over multiple lines + ''' + version = '' + + +class TestHelpWrappingShortNames(HelpTestCase): + """Make sure that text after short names starts on the first line""" + + parser_signature = Sig(prog='PROG', description= 'D\nD' * 30) + argument_signatures = [ + Sig('-x', metavar='XX', help='XHH HX' * 20), + Sig('y', metavar='yyy', help='YH YH' * 20), + ] + argument_group_signatures = [ + (Sig('ALPHAS'), [ + Sig('-a', action='store_true', help='AHHH HHA' * 10)]), + ] + usage = '''\ + usage: PROG [-h] [-x XX] [-a] yyy + ''' + help = usage + '''\ + + D DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD \ +DD DD DD + DD DD DD DD D + + positional arguments: + yyy YH YHYH YHYH YHYH YHYH YHYH YHYH YHYH YHYH YHYH YHYH \ +YHYH YHYH + YHYH YHYH YHYH YHYH YHYH YHYH YHYH YH + + optional arguments: + -h, --help show this help message and exit + -x XX XHH HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH \ +HXXHH HXXHH + HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH HXXHH HX + + ALPHAS: + -a AHHH HHAAHHH HHAAHHH HHAAHHH HHAAHHH HHAAHHH HHAAHHH \ +HHAAHHH + HHAAHHH HHAAHHH HHA + ''' + version = '' + + +class TestHelpWrappingLongNames(HelpTestCase): + """Make sure that text after long names starts on the next line""" + + parser_signature = Sig(usage='USAGE', description= 'D D' * 30) + argument_signatures = [ + Sig('-v', '--version', action='version', version='V V' * 30), + Sig('-x', metavar='X' * 25, help='XH XH' * 20), + Sig('y', metavar='y' * 25, help='YH YH' * 20), + ] + argument_group_signatures = [ + (Sig('ALPHAS'), [ + Sig('-a', metavar='A' * 25, help='AH AH' * 20), + Sig('z', metavar='z' * 25, help='ZH ZH' * 20)]), + ] + usage = '''\ + usage: USAGE + ''' + help = usage + '''\ + + D DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD DD \ +DD DD DD + DD DD DD DD D + + positional arguments: + yyyyyyyyyyyyyyyyyyyyyyyyy + YH YHYH YHYH YHYH YHYH YHYH YHYH YHYH YHYH \ +YHYH YHYH + YHYH YHYH YHYH YHYH YHYH YHYH YHYH YHYH YHYH YH + + optional arguments: + -h, --help show this help message and exit + -v, --version show program's version number and exit + -x XXXXXXXXXXXXXXXXXXXXXXXXX + XH XHXH XHXH XHXH XHXH XHXH XHXH XHXH XHXH \ +XHXH XHXH + XHXH XHXH XHXH XHXH XHXH XHXH XHXH XHXH XHXH XH + + ALPHAS: + -a AAAAAAAAAAAAAAAAAAAAAAAAA + AH AHAH AHAH AHAH AHAH AHAH AHAH AHAH AHAH \ +AHAH AHAH + AHAH AHAH AHAH AHAH AHAH AHAH AHAH AHAH AHAH AH + zzzzzzzzzzzzzzzzzzzzzzzzz + ZH ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH \ +ZHZH ZHZH + ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH ZHZH ZH + ''' + version = '''\ + V VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV VV \ +VV VV VV + VV VV VV VV V + ''' + + +class TestHelpUsage(HelpTestCase): + """Test basic usage messages""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-w', nargs='+', help='w'), + Sig('-x', nargs='*', help='x'), + Sig('a', help='a'), + Sig('b', help='b', nargs=2), + Sig('c', help='c', nargs='?'), + ] + argument_group_signatures = [ + (Sig('group'), [ + Sig('-y', nargs='?', help='y'), + Sig('-z', nargs=3, help='z'), + Sig('d', help='d', nargs='*'), + Sig('e', help='e', nargs='+'), + ]) + ] + usage = '''\ + usage: PROG [-h] [-w W [W ...]] [-x [X [X ...]]] [-y [Y]] [-z Z Z Z] + a b b [c] [d [d ...]] e [e ...] + ''' + help = usage + '''\ + + positional arguments: + a a + b b + c c + + optional arguments: + -h, --help show this help message and exit + -w W [W ...] w + -x [X [X ...]] x + + group: + -y [Y] y + -z Z Z Z z + d d + e e + ''' + version = '' + + +class TestHelpOnlyUserGroups(HelpTestCase): + """Test basic usage messages""" + + parser_signature = Sig(prog='PROG', add_help=False) + argument_signatures = [] + argument_group_signatures = [ + (Sig('xxxx'), [ + Sig('-x', help='x'), + Sig('a', help='a'), + ]), + (Sig('yyyy'), [ + Sig('b', help='b'), + Sig('-y', help='y'), + ]), + ] + usage = '''\ + usage: PROG [-x X] [-y Y] a b + ''' + help = usage + '''\ + + xxxx: + -x X x + a a + + yyyy: + b b + -y Y y + ''' + version = '' + + +class TestHelpUsageLongProg(HelpTestCase): + """Test usage messages where the prog is long""" + + parser_signature = Sig(prog='P' * 60) + argument_signatures = [ + Sig('-w', metavar='W'), + Sig('-x', metavar='X'), + Sig('a'), + Sig('b'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP + [-h] [-w W] [-x X] a b + ''' + help = usage + '''\ + + positional arguments: + a + b + + optional arguments: + -h, --help show this help message and exit + -w W + -x X + ''' + version = '' + + +class TestHelpUsageLongProgOptionsWrap(HelpTestCase): + """Test usage messages where the prog is long and the optionals wrap""" + + parser_signature = Sig(prog='P' * 60) + argument_signatures = [ + Sig('-w', metavar='W' * 25), + Sig('-x', metavar='X' * 25), + Sig('-y', metavar='Y' * 25), + Sig('-z', metavar='Z' * 25), + Sig('a'), + Sig('b'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP + [-h] [-w WWWWWWWWWWWWWWWWWWWWWWWWW] \ +[-x XXXXXXXXXXXXXXXXXXXXXXXXX] + [-y YYYYYYYYYYYYYYYYYYYYYYYYY] [-z ZZZZZZZZZZZZZZZZZZZZZZZZZ] + a b + ''' + help = usage + '''\ + + positional arguments: + a + b + + optional arguments: + -h, --help show this help message and exit + -w WWWWWWWWWWWWWWWWWWWWWWWWW + -x XXXXXXXXXXXXXXXXXXXXXXXXX + -y YYYYYYYYYYYYYYYYYYYYYYYYY + -z ZZZZZZZZZZZZZZZZZZZZZZZZZ + ''' + version = '' + + +class TestHelpUsageLongProgPositionalsWrap(HelpTestCase): + """Test usage messages where the prog is long and the positionals wrap""" + + parser_signature = Sig(prog='P' * 60, add_help=False) + argument_signatures = [ + Sig('a' * 25), + Sig('b' * 25), + Sig('c' * 25), + ] + argument_group_signatures = [] + usage = '''\ + usage: PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP + aaaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + ''' + help = usage + '''\ + + positional arguments: + aaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + ''' + version = '' + + +class TestHelpUsageOptionalsWrap(HelpTestCase): + """Test usage messages where the optionals wrap""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-w', metavar='W' * 25), + Sig('-x', metavar='X' * 25), + Sig('-y', metavar='Y' * 25), + Sig('-z', metavar='Z' * 25), + Sig('a'), + Sig('b'), + Sig('c'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-w WWWWWWWWWWWWWWWWWWWWWWWWW] \ +[-x XXXXXXXXXXXXXXXXXXXXXXXXX] + [-y YYYYYYYYYYYYYYYYYYYYYYYYY] \ +[-z ZZZZZZZZZZZZZZZZZZZZZZZZZ] + a b c + ''' + help = usage + '''\ + + positional arguments: + a + b + c + + optional arguments: + -h, --help show this help message and exit + -w WWWWWWWWWWWWWWWWWWWWWWWWW + -x XXXXXXXXXXXXXXXXXXXXXXXXX + -y YYYYYYYYYYYYYYYYYYYYYYYYY + -z ZZZZZZZZZZZZZZZZZZZZZZZZZ + ''' + version = '' + + +class TestHelpUsagePositionalsWrap(HelpTestCase): + """Test usage messages where the positionals wrap""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-x'), + Sig('-y'), + Sig('-z'), + Sig('a' * 25), + Sig('b' * 25), + Sig('c' * 25), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-x X] [-y Y] [-z Z] + aaaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + ''' + help = usage + '''\ + + positional arguments: + aaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + + optional arguments: + -h, --help show this help message and exit + -x X + -y Y + -z Z + ''' + version = '' + + +class TestHelpUsageOptionalsPositionalsWrap(HelpTestCase): + """Test usage messages where the optionals and positionals wrap""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-x', metavar='X' * 25), + Sig('-y', metavar='Y' * 25), + Sig('-z', metavar='Z' * 25), + Sig('a' * 25), + Sig('b' * 25), + Sig('c' * 25), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-x XXXXXXXXXXXXXXXXXXXXXXXXX] \ +[-y YYYYYYYYYYYYYYYYYYYYYYYYY] + [-z ZZZZZZZZZZZZZZZZZZZZZZZZZ] + aaaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + ''' + help = usage + '''\ + + positional arguments: + aaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + + optional arguments: + -h, --help show this help message and exit + -x XXXXXXXXXXXXXXXXXXXXXXXXX + -y YYYYYYYYYYYYYYYYYYYYYYYYY + -z ZZZZZZZZZZZZZZZZZZZZZZZZZ + ''' + version = '' + + +class TestHelpUsageOptionalsOnlyWrap(HelpTestCase): + """Test usage messages where there are only optionals and they wrap""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-x', metavar='X' * 25), + Sig('-y', metavar='Y' * 25), + Sig('-z', metavar='Z' * 25), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-x XXXXXXXXXXXXXXXXXXXXXXXXX] \ +[-y YYYYYYYYYYYYYYYYYYYYYYYYY] + [-z ZZZZZZZZZZZZZZZZZZZZZZZZZ] + ''' + help = usage + '''\ + + optional arguments: + -h, --help show this help message and exit + -x XXXXXXXXXXXXXXXXXXXXXXXXX + -y YYYYYYYYYYYYYYYYYYYYYYYYY + -z ZZZZZZZZZZZZZZZZZZZZZZZZZ + ''' + version = '' + + +class TestHelpUsagePositionalsOnlyWrap(HelpTestCase): + """Test usage messages where there are only positionals and they wrap""" + + parser_signature = Sig(prog='PROG', add_help=False) + argument_signatures = [ + Sig('a' * 25), + Sig('b' * 25), + Sig('c' * 25), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG aaaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + ''' + help = usage + '''\ + + positional arguments: + aaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbb + ccccccccccccccccccccccccc + ''' + version = '' + + +class TestHelpVariableExpansion(HelpTestCase): + """Test that variables are expanded properly in help messages""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-x', type=int, + help='x %(prog)s %(default)s %(type)s %%'), + Sig('-y', action='store_const', default=42, const='XXX', + help='y %(prog)s %(default)s %(const)s'), + Sig('--foo', choices='abc', + help='foo %(prog)s %(default)s %(choices)s'), + Sig('--bar', default='baz', choices=[1, 2], metavar='BBB', + help='bar %(prog)s %(default)s %(dest)s'), + Sig('spam', help='spam %(prog)s %(default)s'), + Sig('badger', default=0.5, help='badger %(prog)s %(default)s'), + ] + argument_group_signatures = [ + (Sig('group'), [ + Sig('-a', help='a %(prog)s %(default)s'), + Sig('-b', default=-1, help='b %(prog)s %(default)s'), + ]) + ] + usage = ('''\ + usage: PROG [-h] [-x X] [-y] [--foo {a,b,c}] [--bar BBB] [-a A] [-b B] + spam badger + ''') + help = usage + '''\ + + positional arguments: + spam spam PROG None + badger badger PROG 0.5 + + optional arguments: + -h, --help show this help message and exit + -x X x PROG None int % + -y y PROG 42 XXX + --foo {a,b,c} foo PROG None a, b, c + --bar BBB bar PROG baz bar + + group: + -a A a PROG None + -b B b PROG -1 + ''' + version = '' + + +class TestHelpVariableExpansionUsageSupplied(HelpTestCase): + """Test that variables are expanded properly when usage= is present""" + + parser_signature = Sig(prog='PROG', usage='%(prog)s FOO') + argument_signatures = [] + argument_group_signatures = [] + usage = ('''\ + usage: PROG FOO + ''') + help = usage + '''\ + + optional arguments: + -h, --help show this help message and exit + ''' + version = '' + + +class TestHelpVariableExpansionNoArguments(HelpTestCase): + """Test that variables are expanded properly with no arguments""" + + parser_signature = Sig(prog='PROG', add_help=False) + argument_signatures = [] + argument_group_signatures = [] + usage = ('''\ + usage: PROG + ''') + help = usage + version = '' + + +class TestHelpSuppressUsage(HelpTestCase): + """Test that items can be suppressed in usage messages""" + + parser_signature = Sig(prog='PROG', usage=argparse.SUPPRESS) + argument_signatures = [ + Sig('--foo', help='foo help'), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [] + help = '''\ + positional arguments: + spam spam help + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help + ''' + usage = '' + version = '' + + +class TestHelpSuppressOptional(HelpTestCase): + """Test that optional arguments can be suppressed in help messages""" + + parser_signature = Sig(prog='PROG', add_help=False) + argument_signatures = [ + Sig('--foo', help=argparse.SUPPRESS), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG spam + ''' + help = usage + '''\ + + positional arguments: + spam spam help + ''' + version = '' + + +class TestHelpSuppressOptionalGroup(HelpTestCase): + """Test that optional groups can be suppressed in help messages""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('--foo', help='foo help'), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [ + (Sig('group'), [Sig('--bar', help=argparse.SUPPRESS)]), + ] + usage = '''\ + usage: PROG [-h] [--foo FOO] spam + ''' + help = usage + '''\ + + positional arguments: + spam spam help + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help + ''' + version = '' + + +class TestHelpSuppressPositional(HelpTestCase): + """Test that positional arguments can be suppressed in help messages""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('--foo', help='foo help'), + Sig('spam', help=argparse.SUPPRESS), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [--foo FOO] + ''' + help = usage + '''\ + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help + ''' + version = '' + + +class TestHelpRequiredOptional(HelpTestCase): + """Test that required options don't look optional""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('--foo', required=True, help='foo help'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] --foo FOO + ''' + help = usage + '''\ + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help + ''' + version = '' + + +class TestHelpAlternatePrefixChars(HelpTestCase): + """Test that options display with different prefix characters""" + + parser_signature = Sig(prog='PROG', prefix_chars='^;', add_help=False) + argument_signatures = [ + Sig('^^foo', action='store_true', help='foo help'), + Sig(';b', ';;bar', help='bar help'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [^^foo] [;b BAR] + ''' + help = usage + '''\ + + optional arguments: + ^^foo foo help + ;b BAR, ;;bar BAR bar help + ''' + version = '' + + +class TestHelpNoHelpOptional(HelpTestCase): + """Test that the --help argument can be suppressed help messages""" + + parser_signature = Sig(prog='PROG', add_help=False) + argument_signatures = [ + Sig('--foo', help='foo help'), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [--foo FOO] spam + ''' + help = usage + '''\ + + positional arguments: + spam spam help + + optional arguments: + --foo FOO foo help + ''' + version = '' + + +class TestHelpNone(HelpTestCase): + """Test that no errors occur if no help is specified""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('--foo'), + Sig('spam'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [--foo FOO] spam + ''' + help = usage + '''\ + + positional arguments: + spam + + optional arguments: + -h, --help show this help message and exit + --foo FOO + ''' + version = '' + + +class TestHelpTupleMetavar(HelpTestCase): + """Test specifying metavar as a tuple""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-w', help='w', nargs='+', metavar=('W1', 'W2')), + Sig('-x', help='x', nargs='*', metavar=('X1', 'X2')), + Sig('-y', help='y', nargs=3, metavar=('Y1', 'Y2', 'Y3')), + Sig('-z', help='z', nargs='?', metavar=('Z1', )), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-w W1 [W2 ...]] [-x [X1 [X2 ...]]] [-y Y1 Y2 Y3] \ +[-z [Z1]] + ''' + help = usage + '''\ + + optional arguments: + -h, --help show this help message and exit + -w W1 [W2 ...] w + -x [X1 [X2 ...]] x + -y Y1 Y2 Y3 y + -z [Z1] z + ''' + version = '' + + +class TestHelpRawText(HelpTestCase): + """Test the RawTextHelpFormatter""" + + parser_signature = Sig( + prog='PROG', formatter_class=argparse.RawTextHelpFormatter, + description='Keep the formatting\n' + ' exactly as it is written\n' + '\n' + 'here\n') + + argument_signatures = [ + Sig('--foo', help=' foo help should also\n' + 'appear as given here'), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [ + (Sig('title', description=' This text\n' + ' should be indented\n' + ' exactly like it is here\n'), + [Sig('--bar', help='bar help')]), + ] + usage = '''\ + usage: PROG [-h] [--foo FOO] [--bar BAR] spam + ''' + help = usage + '''\ + + Keep the formatting + exactly as it is written + + here + + positional arguments: + spam spam help + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help should also + appear as given here + + title: + This text + should be indented + exactly like it is here + + --bar BAR bar help + ''' + version = '' + + +class TestHelpRawDescription(HelpTestCase): + """Test the RawTextHelpFormatter""" + + parser_signature = Sig( + prog='PROG', formatter_class=argparse.RawDescriptionHelpFormatter, + description='Keep the formatting\n' + ' exactly as it is written\n' + '\n' + 'here\n') + + argument_signatures = [ + Sig('--foo', help=' foo help should not\n' + ' retain this odd formatting'), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [ + (Sig('title', description=' This text\n' + ' should be indented\n' + ' exactly like it is here\n'), + [Sig('--bar', help='bar help')]), + ] + usage = '''\ + usage: PROG [-h] [--foo FOO] [--bar BAR] spam + ''' + help = usage + '''\ + + Keep the formatting + exactly as it is written + + here + + positional arguments: + spam spam help + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help should not retain this odd formatting + + title: + This text + should be indented + exactly like it is here + + --bar BAR bar help + ''' + version = '' + + +class TestHelpArgumentDefaults(HelpTestCase): + """Test the ArgumentDefaultsHelpFormatter""" + + parser_signature = Sig( + prog='PROG', formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description='description') + + argument_signatures = [ + Sig('--foo', help='foo help - oh and by the way, %(default)s'), + Sig('--bar', action='store_true', help='bar help'), + Sig('spam', help='spam help'), + Sig('badger', nargs='?', default='wooden', help='badger help'), + ] + argument_group_signatures = [ + (Sig('title', description='description'), + [Sig('--baz', type=int, default=42, help='baz help')]), + ] + usage = '''\ + usage: PROG [-h] [--foo FOO] [--bar] [--baz BAZ] spam [badger] + ''' + help = usage + '''\ + + description + + positional arguments: + spam spam help + badger badger help (default: wooden) + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help - oh and by the way, None + --bar bar help (default: False) + + title: + description + + --baz BAZ baz help (default: 42) + ''' + version = '' + +class TestHelpVersionAction(HelpTestCase): + """Test the default help for the version action""" + + parser_signature = Sig(prog='PROG', description='description') + argument_signatures = [Sig('-V', '--version', action='version', version='3.6')] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-V] + ''' + help = usage + '''\ + + description + + optional arguments: + -h, --help show this help message and exit + -V, --version show program's version number and exit + ''' + version = '' + + +class TestHelpVersionActionSuppress(HelpTestCase): + """Test that the --version argument can be suppressed in help messages""" + + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('-v', '--version', action='version', version='1.0', + help=argparse.SUPPRESS), + Sig('--foo', help='foo help'), + Sig('spam', help='spam help'), + ] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [--foo FOO] spam + ''' + help = usage + '''\ + + positional arguments: + spam spam help + + optional arguments: + -h, --help show this help message and exit + --foo FOO foo help + ''' + + +class TestHelpSubparsersOrdering(HelpTestCase): + """Test ordering of subcommands in help matches the code""" + parser_signature = Sig(prog='PROG', + description='display some subcommands') + argument_signatures = [Sig('-v', '--version', action='version', version='0.1')] + + subparsers_signatures = [Sig(name=name) + for name in ('a', 'b', 'c', 'd', 'e')] + + usage = '''\ + usage: PROG [-h] [-v] {a,b,c,d,e} ... + ''' + + help = usage + '''\ + + display some subcommands + + positional arguments: + {a,b,c,d,e} + + optional arguments: + -h, --help show this help message and exit + -v, --version show program's version number and exit + ''' + + version = '''\ + 0.1 + ''' + +class TestHelpSubparsersWithHelpOrdering(HelpTestCase): + """Test ordering of subcommands in help matches the code""" + parser_signature = Sig(prog='PROG', + description='display some subcommands') + argument_signatures = [Sig('-v', '--version', action='version', version='0.1')] + + subcommand_data = (('a', 'a subcommand help'), + ('b', 'b subcommand help'), + ('c', 'c subcommand help'), + ('d', 'd subcommand help'), + ('e', 'e subcommand help'), + ) + + subparsers_signatures = [Sig(name=name, help=help) + for name, help in subcommand_data] + + usage = '''\ + usage: PROG [-h] [-v] {a,b,c,d,e} ... + ''' + + help = usage + '''\ + + display some subcommands + + positional arguments: + {a,b,c,d,e} + a a subcommand help + b b subcommand help + c c subcommand help + d d subcommand help + e e subcommand help + + optional arguments: + -h, --help show this help message and exit + -v, --version show program's version number and exit + ''' + + version = '''\ + 0.1 + ''' + + + +class TestHelpMetavarTypeFormatter(HelpTestCase): + """""" + + def custom_type(string): + return string + + parser_signature = Sig(prog='PROG', description='description', + formatter_class=argparse.MetavarTypeHelpFormatter) + argument_signatures = [Sig('a', type=int), + Sig('-b', type=custom_type), + Sig('-c', type=float, metavar='SOME FLOAT')] + argument_group_signatures = [] + usage = '''\ + usage: PROG [-h] [-b custom_type] [-c SOME FLOAT] int + ''' + help = usage + '''\ + + description + + positional arguments: + int + + optional arguments: + -h, --help show this help message and exit + -b custom_type + -c SOME FLOAT + ''' + version = '' + + +# ===================================== +# Optional/Positional constructor tests +# ===================================== + +class TestInvalidArgumentConstructors(TestCase): + """Test a bunch of invalid Argument constructors""" + + def assertTypeError(self, *args, **kwargs): + parser = argparse.ArgumentParser() + self.assertRaises(TypeError, parser.add_argument, + *args, **kwargs) + + def assertValueError(self, *args, **kwargs): + parser = argparse.ArgumentParser() + self.assertRaises(ValueError, parser.add_argument, + *args, **kwargs) + + def test_invalid_keyword_arguments(self): + self.assertTypeError('-x', bar=None) + self.assertTypeError('-y', callback='foo') + self.assertTypeError('-y', callback_args=()) + self.assertTypeError('-y', callback_kwargs={}) + + def test_missing_destination(self): + self.assertTypeError() + for action in ['append', 'store']: + self.assertTypeError(action=action) + + def test_invalid_option_strings(self): + self.assertValueError('--') + self.assertValueError('---') + + def test_invalid_type(self): + self.assertValueError('--foo', type='int') + self.assertValueError('--foo', type=(int, float)) + + def test_invalid_action(self): + self.assertValueError('-x', action='foo') + self.assertValueError('foo', action='baz') + self.assertValueError('--foo', action=('store', 'append')) + parser = argparse.ArgumentParser() + with self.assertRaises(ValueError) as cm: + parser.add_argument("--foo", action="store-true") + self.assertIn('unknown action', str(cm.exception)) + + def test_multiple_dest(self): + parser = argparse.ArgumentParser() + parser.add_argument(dest='foo') + with self.assertRaises(ValueError) as cm: + parser.add_argument('bar', dest='baz') + self.assertIn('dest supplied twice for positional argument', + str(cm.exception)) + + def test_no_argument_actions(self): + for action in ['store_const', 'store_true', 'store_false', + 'append_const', 'count']: + for attrs in [dict(type=int), dict(nargs='+'), + dict(choices='ab')]: + self.assertTypeError('-x', action=action, **attrs) + + def test_no_argument_no_const_actions(self): + # options with zero arguments + for action in ['store_true', 'store_false', 'count']: + + # const is always disallowed + self.assertTypeError('-x', const='foo', action=action) + + # nargs is always disallowed + self.assertTypeError('-x', nargs='*', action=action) + + def test_more_than_one_argument_actions(self): + for action in ['store', 'append']: + + # nargs=0 is disallowed + self.assertValueError('-x', nargs=0, action=action) + self.assertValueError('spam', nargs=0, action=action) + + # const is disallowed with non-optional arguments + for nargs in [1, '*', '+']: + self.assertValueError('-x', const='foo', + nargs=nargs, action=action) + self.assertValueError('spam', const='foo', + nargs=nargs, action=action) + + def test_required_const_actions(self): + for action in ['store_const', 'append_const']: + + # nargs is always disallowed + self.assertTypeError('-x', nargs='+', action=action) + + def test_parsers_action_missing_params(self): + self.assertTypeError('command', action='parsers') + self.assertTypeError('command', action='parsers', prog='PROG') + self.assertTypeError('command', action='parsers', + parser_class=argparse.ArgumentParser) + + def test_required_positional(self): + self.assertTypeError('foo', required=True) + + def test_user_defined_action(self): + + class Success(Exception): + pass + + class Action(object): + + def __init__(self, + option_strings, + dest, + const, + default, + required=False): + if dest == 'spam': + if const is Success: + if default is Success: + raise Success() + + def __call__(self, *args, **kwargs): + pass + + parser = argparse.ArgumentParser() + self.assertRaises(Success, parser.add_argument, '--spam', + action=Action, default=Success, const=Success) + self.assertRaises(Success, parser.add_argument, 'spam', + action=Action, default=Success, const=Success) + +# ================================ +# Actions returned by add_argument +# ================================ + +class TestActionsReturned(TestCase): + + def test_dest(self): + parser = argparse.ArgumentParser() + action = parser.add_argument('--foo') + self.assertEqual(action.dest, 'foo') + action = parser.add_argument('-b', '--bar') + self.assertEqual(action.dest, 'bar') + action = parser.add_argument('-x', '-y') + self.assertEqual(action.dest, 'x') + + def test_misc(self): + parser = argparse.ArgumentParser() + action = parser.add_argument('--foo', nargs='?', const=42, + default=84, type=int, choices=[1, 2], + help='FOO', metavar='BAR', dest='baz') + self.assertEqual(action.nargs, '?') + self.assertEqual(action.const, 42) + self.assertEqual(action.default, 84) + self.assertEqual(action.type, int) + self.assertEqual(action.choices, [1, 2]) + self.assertEqual(action.help, 'FOO') + self.assertEqual(action.metavar, 'BAR') + self.assertEqual(action.dest, 'baz') + + +# ================================ +# Argument conflict handling tests +# ================================ + +class TestConflictHandling(TestCase): + + def test_bad_type(self): + self.assertRaises(ValueError, argparse.ArgumentParser, + conflict_handler='foo') + + def test_conflict_error(self): + parser = argparse.ArgumentParser() + parser.add_argument('-x') + self.assertRaises(argparse.ArgumentError, + parser.add_argument, '-x') + parser.add_argument('--spam') + self.assertRaises(argparse.ArgumentError, + parser.add_argument, '--spam') + + def test_resolve_error(self): + get_parser = argparse.ArgumentParser + parser = get_parser(prog='PROG', conflict_handler='resolve') + + parser.add_argument('-x', help='OLD X') + parser.add_argument('-x', help='NEW X') + self.assertEqual(parser.format_help(), textwrap.dedent('''\ + usage: PROG [-h] [-x X] + + optional arguments: + -h, --help show this help message and exit + -x X NEW X + ''')) + + parser.add_argument('--spam', metavar='OLD_SPAM') + parser.add_argument('--spam', metavar='NEW_SPAM') + self.assertEqual(parser.format_help(), textwrap.dedent('''\ + usage: PROG [-h] [-x X] [--spam NEW_SPAM] + + optional arguments: + -h, --help show this help message and exit + -x X NEW X + --spam NEW_SPAM + ''')) + + +# ============================= +# Help and Version option tests +# ============================= + +class TestOptionalsHelpVersionActions(TestCase): + """Test the help and version actions""" + + def assertPrintHelpExit(self, parser, args_str): + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(args_str.split()) + self.assertEqual(parser.format_help(), cm.exception.stdout) + + def assertArgumentParserError(self, parser, *args): + self.assertRaises(ArgumentParserError, parser.parse_args, args) + + def test_version(self): + parser = ErrorRaisingArgumentParser() + parser.add_argument('-v', '--version', action='version', version='1.0') + self.assertPrintHelpExit(parser, '-h') + self.assertPrintHelpExit(parser, '--help') + self.assertRaises(AttributeError, getattr, parser, 'format_version') + + def test_version_format(self): + parser = ErrorRaisingArgumentParser(prog='PPP') + parser.add_argument('-v', '--version', action='version', version='%(prog)s 3.5') + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['-v']) + self.assertEqual('PPP 3.5\n', cm.exception.stdout) + + def test_version_no_help(self): + parser = ErrorRaisingArgumentParser(add_help=False) + parser.add_argument('-v', '--version', action='version', version='1.0') + self.assertArgumentParserError(parser, '-h') + self.assertArgumentParserError(parser, '--help') + self.assertRaises(AttributeError, getattr, parser, 'format_version') + + def test_version_action(self): + parser = ErrorRaisingArgumentParser(prog='XXX') + parser.add_argument('-V', action='version', version='%(prog)s 3.7') + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['-V']) + self.assertEqual('XXX 3.7\n', cm.exception.stdout) + + def test_no_help(self): + parser = ErrorRaisingArgumentParser(add_help=False) + self.assertArgumentParserError(parser, '-h') + self.assertArgumentParserError(parser, '--help') + self.assertArgumentParserError(parser, '-v') + self.assertArgumentParserError(parser, '--version') + + def test_alternate_help_version(self): + parser = ErrorRaisingArgumentParser() + parser.add_argument('-x', action='help') + parser.add_argument('-y', action='version') + self.assertPrintHelpExit(parser, '-x') + self.assertArgumentParserError(parser, '-v') + self.assertArgumentParserError(parser, '--version') + self.assertRaises(AttributeError, getattr, parser, 'format_version') + + def test_help_version_extra_arguments(self): + parser = ErrorRaisingArgumentParser() + parser.add_argument('--version', action='version', version='1.0') + parser.add_argument('-x', action='store_true') + parser.add_argument('y') + + # try all combinations of valid prefixes and suffixes + valid_prefixes = ['', '-x', 'foo', '-x bar', 'baz -x'] + valid_suffixes = valid_prefixes + ['--bad-option', 'foo bar baz'] + for prefix in valid_prefixes: + for suffix in valid_suffixes: + format = '%s %%s %s' % (prefix, suffix) + self.assertPrintHelpExit(parser, format % '-h') + self.assertPrintHelpExit(parser, format % '--help') + self.assertRaises(AttributeError, getattr, parser, 'format_version') + + +# ====================== +# str() and repr() tests +# ====================== + +class TestStrings(TestCase): + """Test str() and repr() on Optionals and Positionals""" + + def assertStringEqual(self, obj, result_string): + for func in [str, repr]: + self.assertEqual(func(obj), result_string) + + def test_optional(self): + option = argparse.Action( + option_strings=['--foo', '-a', '-b'], + dest='b', + type='int', + nargs='+', + default=42, + choices=[1, 2, 3], + help='HELP', + metavar='METAVAR') + string = ( + "Action(option_strings=['--foo', '-a', '-b'], dest='b', " + "nargs='+', const=None, default=42, type='int', " + "choices=[1, 2, 3], help='HELP', metavar='METAVAR')") + self.assertStringEqual(option, string) + + def test_argument(self): + argument = argparse.Action( + option_strings=[], + dest='x', + type=float, + nargs='?', + default=2.5, + choices=[0.5, 1.5, 2.5], + help='H HH H', + metavar='MV MV MV') + string = ( + "Action(option_strings=[], dest='x', nargs='?', " + "const=None, default=2.5, type=%r, choices=[0.5, 1.5, 2.5], " + "help='H HH H', metavar='MV MV MV')" % float) + self.assertStringEqual(argument, string) + + def test_namespace(self): + ns = argparse.Namespace(foo=42, bar='spam') + string = "Namespace(bar='spam', foo=42)" + self.assertStringEqual(ns, string) + + def test_namespace_starkwargs_notidentifier(self): + ns = argparse.Namespace(**{'"': 'quote'}) + string = """Namespace(**{'"': 'quote'})""" + self.assertStringEqual(ns, string) + + def test_namespace_kwargs_and_starkwargs_notidentifier(self): + ns = argparse.Namespace(a=1, **{'"': 'quote'}) + string = """Namespace(a=1, **{'"': 'quote'})""" + self.assertStringEqual(ns, string) + + def test_namespace_starkwargs_identifier(self): + ns = argparse.Namespace(**{'valid': True}) + string = "Namespace(valid=True)" + self.assertStringEqual(ns, string) + + def test_parser(self): + parser = argparse.ArgumentParser(prog='PROG') + string = ( + "ArgumentParser(prog='PROG', usage=None, description=None, " + "formatter_class=%r, conflict_handler='error', " + "add_help=True)" % argparse.HelpFormatter) + self.assertStringEqual(parser, string) + +# =============== +# Namespace tests +# =============== + +class TestNamespace(TestCase): + + def test_constructor(self): + ns = argparse.Namespace() + self.assertRaises(AttributeError, getattr, ns, 'x') + + ns = argparse.Namespace(a=42, b='spam') + self.assertEqual(ns.a, 42) + self.assertEqual(ns.b, 'spam') + + def test_equality(self): + ns1 = argparse.Namespace(a=1, b=2) + ns2 = argparse.Namespace(b=2, a=1) + ns3 = argparse.Namespace(a=1) + ns4 = argparse.Namespace(b=2) + + self.assertEqual(ns1, ns2) + self.assertNotEqual(ns1, ns3) + self.assertNotEqual(ns1, ns4) + self.assertNotEqual(ns2, ns3) + self.assertNotEqual(ns2, ns4) + self.assertTrue(ns1 != ns3) + self.assertTrue(ns1 != ns4) + self.assertTrue(ns2 != ns3) + self.assertTrue(ns2 != ns4) + + def test_equality_returns_notimplemented(self): + # See issue 21481 + ns = argparse.Namespace(a=1, b=2) + self.assertIs(ns.__eq__(None), NotImplemented) + self.assertIs(ns.__ne__(None), NotImplemented) + + +# =================== +# File encoding tests +# =================== + +class TestEncoding(TestCase): + + def _test_module_encoding(self, path): + path, _ = os.path.splitext(path) + path += ".py" + with open(path, 'r', encoding='utf-8') as f: + f.read() + + def test_argparse_module_encoding(self): + self._test_module_encoding(argparse.__file__) + + def test_test_argparse_module_encoding(self): + self._test_module_encoding(__file__) + +# =================== +# ArgumentError tests +# =================== + +class TestArgumentError(TestCase): + + def test_argument_error(self): + msg = "my error here" + error = argparse.ArgumentError(None, msg) + self.assertEqual(str(error), msg) + +# ======================= +# ArgumentTypeError tests +# ======================= + +class TestArgumentTypeError(TestCase): + + def test_argument_type_error(self): + + def spam(string): + raise argparse.ArgumentTypeError('spam!') + + parser = ErrorRaisingArgumentParser(prog='PROG', add_help=False) + parser.add_argument('x', type=spam) + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['XXX']) + self.assertEqual('usage: PROG x\nPROG: error: argument x: spam!\n', + cm.exception.stderr) + +# ========================= +# MessageContentError tests +# ========================= + +class TestMessageContentError(TestCase): + + def test_missing_argument_name_in_message(self): + parser = ErrorRaisingArgumentParser(prog='PROG', usage='') + parser.add_argument('req_pos', type=str) + parser.add_argument('-req_opt', type=int, required=True) + parser.add_argument('need_one', type=str, nargs='+') + + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args([]) + msg = str(cm.exception) + self.assertRegex(msg, 'req_pos') + self.assertRegex(msg, 'req_opt') + self.assertRegex(msg, 'need_one') + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['myXargument']) + msg = str(cm.exception) + self.assertNotIn(msg, 'req_pos') + self.assertRegex(msg, 'req_opt') + self.assertRegex(msg, 'need_one') + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['myXargument', '-req_opt=1']) + msg = str(cm.exception) + self.assertNotIn(msg, 'req_pos') + self.assertNotIn(msg, 'req_opt') + self.assertRegex(msg, 'need_one') + + def test_optional_optional_not_in_message(self): + parser = ErrorRaisingArgumentParser(prog='PROG', usage='') + parser.add_argument('req_pos', type=str) + parser.add_argument('--req_opt', type=int, required=True) + parser.add_argument('--opt_opt', type=bool, nargs='?', + default=True) + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args([]) + msg = str(cm.exception) + self.assertRegex(msg, 'req_pos') + self.assertRegex(msg, 'req_opt') + self.assertNotIn(msg, 'opt_opt') + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['--req_opt=1']) + msg = str(cm.exception) + self.assertRegex(msg, 'req_pos') + self.assertNotIn(msg, 'req_opt') + self.assertNotIn(msg, 'opt_opt') + + def test_optional_positional_not_in_message(self): + parser = ErrorRaisingArgumentParser(prog='PROG', usage='') + parser.add_argument('req_pos') + parser.add_argument('optional_positional', nargs='?', default='eggs') + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args([]) + msg = str(cm.exception) + self.assertRegex(msg, 'req_pos') + self.assertNotIn(msg, 'optional_positional') + + +# ================================================ +# Check that the type function is called only once +# ================================================ + +class TestTypeFunctionCallOnlyOnce(TestCase): + + def test_type_function_call_only_once(self): + def spam(string_to_convert): + self.assertEqual(string_to_convert, 'spam!') + return 'foo_converted' + + parser = argparse.ArgumentParser() + parser.add_argument('--foo', type=spam, default='bar') + args = parser.parse_args('--foo spam!'.split()) + self.assertEqual(NS(foo='foo_converted'), args) + +# ================================================================== +# Check semantics regarding the default argument and type conversion +# ================================================================== + +class TestTypeFunctionCalledOnDefault(TestCase): + + def test_type_function_call_with_non_string_default(self): + def spam(int_to_convert): + self.assertEqual(int_to_convert, 0) + return 'foo_converted' + + parser = argparse.ArgumentParser() + parser.add_argument('--foo', type=spam, default=0) + args = parser.parse_args([]) + # foo should *not* be converted because its default is not a string. + self.assertEqual(NS(foo=0), args) + + def test_type_function_call_with_string_default(self): + def spam(int_to_convert): + return 'foo_converted' + + parser = argparse.ArgumentParser() + parser.add_argument('--foo', type=spam, default='0') + args = parser.parse_args([]) + # foo is converted because its default is a string. + self.assertEqual(NS(foo='foo_converted'), args) + + def test_no_double_type_conversion_of_default(self): + def extend(str_to_convert): + return str_to_convert + '*' + + parser = argparse.ArgumentParser() + parser.add_argument('--test', type=extend, default='*') + args = parser.parse_args([]) + # The test argument will be two stars, one coming from the default + # value and one coming from the type conversion being called exactly + # once. + self.assertEqual(NS(test='**'), args) + + def test_issue_15906(self): + # Issue #15906: When action='append', type=str, default=[] are + # providing, the dest value was the string representation "[]" when it + # should have been an empty list. + parser = argparse.ArgumentParser() + parser.add_argument('--test', dest='test', type=str, + default=[], action='append') + args = parser.parse_args([]) + self.assertEqual(args.test, []) + +# ====================== +# parse_known_args tests +# ====================== + +class TestParseKnownArgs(TestCase): + + def test_arguments_tuple(self): + parser = argparse.ArgumentParser() + parser.parse_args(()) + + def test_arguments_list(self): + parser = argparse.ArgumentParser() + parser.parse_args([]) + + def test_arguments_tuple_positional(self): + parser = argparse.ArgumentParser() + parser.add_argument('x') + parser.parse_args(('x',)) + + def test_arguments_list_positional(self): + parser = argparse.ArgumentParser() + parser.add_argument('x') + parser.parse_args(['x']) + + def test_optionals(self): + parser = argparse.ArgumentParser() + parser.add_argument('--foo') + args, extras = parser.parse_known_args('--foo F --bar --baz'.split()) + self.assertEqual(NS(foo='F'), args) + self.assertEqual(['--bar', '--baz'], extras) + + def test_mixed(self): + parser = argparse.ArgumentParser() + parser.add_argument('-v', nargs='?', const=1, type=int) + parser.add_argument('--spam', action='store_false') + parser.add_argument('badger') + + argv = ["B", "C", "--foo", "-v", "3", "4"] + args, extras = parser.parse_known_args(argv) + self.assertEqual(NS(v=3, spam=True, badger="B"), args) + self.assertEqual(["C", "--foo", "4"], extras) + +# =========================== +# parse_intermixed_args tests +# =========================== + +class TestIntermixedArgs(TestCase): + def test_basic(self): + # test parsing intermixed optionals and positionals + parser = argparse.ArgumentParser(prog='PROG') + parser.add_argument('--foo', dest='foo') + bar = parser.add_argument('--bar', dest='bar', required=True) + parser.add_argument('cmd') + parser.add_argument('rest', nargs='*', type=int) + argv = 'cmd --foo x 1 --bar y 2 3'.split() + args = parser.parse_intermixed_args(argv) + # rest gets [1,2,3] despite the foo and bar strings + self.assertEqual(NS(bar='y', cmd='cmd', foo='x', rest=[1, 2, 3]), args) + + args, extras = parser.parse_known_args(argv) + # cannot parse the '1,2,3' + self.assertEqual(NS(bar='y', cmd='cmd', foo='x', rest=[]), args) + self.assertEqual(["1", "2", "3"], extras) + + argv = 'cmd --foo x 1 --error 2 --bar y 3'.split() + args, extras = parser.parse_known_intermixed_args(argv) + # unknown optionals go into extras + self.assertEqual(NS(bar='y', cmd='cmd', foo='x', rest=[1]), args) + self.assertEqual(['--error', '2', '3'], extras) + + # restores attributes that were temporarily changed + self.assertIsNone(parser.usage) + self.assertEqual(bar.required, True) + + def test_remainder(self): + # Intermixed and remainder are incompatible + parser = ErrorRaisingArgumentParser(prog='PROG') + parser.add_argument('-z') + parser.add_argument('x') + parser.add_argument('y', nargs='...') + argv = 'X A B -z Z'.split() + # intermixed fails with '...' (also 'A...') + # self.assertRaises(TypeError, parser.parse_intermixed_args, argv) + with self.assertRaises(TypeError) as cm: + parser.parse_intermixed_args(argv) + self.assertRegex(str(cm.exception), r'\.\.\.') + + def test_exclusive(self): + # mutually exclusive group; intermixed works fine + parser = ErrorRaisingArgumentParser(prog='PROG') + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--foo', action='store_true', help='FOO') + group.add_argument('--spam', help='SPAM') + parser.add_argument('badger', nargs='*', default='X', help='BADGER') + args = parser.parse_intermixed_args('1 --foo 2'.split()) + self.assertEqual(NS(badger=['1', '2'], foo=True, spam=None), args) + self.assertRaises(ArgumentParserError, parser.parse_intermixed_args, '1 2'.split()) + self.assertEqual(group.required, True) + + def test_exclusive_incompatible(self): + # mutually exclusive group including positional - fail + parser = ErrorRaisingArgumentParser(prog='PROG') + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--foo', action='store_true', help='FOO') + group.add_argument('--spam', help='SPAM') + group.add_argument('badger', nargs='*', default='X', help='BADGER') + self.assertRaises(TypeError, parser.parse_intermixed_args, []) + self.assertEqual(group.required, True) + +class TestIntermixedMessageContentError(TestCase): + # case where Intermixed gives different error message + # error is raised by 1st parsing step + def test_missing_argument_name_in_message(self): + parser = ErrorRaisingArgumentParser(prog='PROG', usage='') + parser.add_argument('req_pos', type=str) + parser.add_argument('-req_opt', type=int, required=True) + + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args([]) + msg = str(cm.exception) + self.assertRegex(msg, 'req_pos') + self.assertRegex(msg, 'req_opt') + + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_intermixed_args([]) + msg = str(cm.exception) + self.assertNotRegex(msg, 'req_pos') + self.assertRegex(msg, 'req_opt') + +# ========================== +# add_argument metavar tests +# ========================== + +class TestAddArgumentMetavar(TestCase): + + EXPECTED_MESSAGE = "length of metavar tuple does not match nargs" + + def do_test_no_exception(self, nargs, metavar): + parser = argparse.ArgumentParser() + parser.add_argument("--foo", nargs=nargs, metavar=metavar) + + def do_test_exception(self, nargs, metavar): + parser = argparse.ArgumentParser() + with self.assertRaises(ValueError) as cm: + parser.add_argument("--foo", nargs=nargs, metavar=metavar) + self.assertEqual(cm.exception.args[0], self.EXPECTED_MESSAGE) + + # Unit tests for different values of metavar when nargs=None + + def test_nargs_None_metavar_string(self): + self.do_test_no_exception(nargs=None, metavar="1") + + def test_nargs_None_metavar_length0(self): + self.do_test_exception(nargs=None, metavar=tuple()) + + def test_nargs_None_metavar_length1(self): + self.do_test_no_exception(nargs=None, metavar=("1",)) + + def test_nargs_None_metavar_length2(self): + self.do_test_exception(nargs=None, metavar=("1", "2")) + + def test_nargs_None_metavar_length3(self): + self.do_test_exception(nargs=None, metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=? + + def test_nargs_optional_metavar_string(self): + self.do_test_no_exception(nargs="?", metavar="1") + + def test_nargs_optional_metavar_length0(self): + self.do_test_exception(nargs="?", metavar=tuple()) + + def test_nargs_optional_metavar_length1(self): + self.do_test_no_exception(nargs="?", metavar=("1",)) + + def test_nargs_optional_metavar_length2(self): + self.do_test_exception(nargs="?", metavar=("1", "2")) + + def test_nargs_optional_metavar_length3(self): + self.do_test_exception(nargs="?", metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=* + + def test_nargs_zeroormore_metavar_string(self): + self.do_test_no_exception(nargs="*", metavar="1") + + def test_nargs_zeroormore_metavar_length0(self): + self.do_test_exception(nargs="*", metavar=tuple()) + + def test_nargs_zeroormore_metavar_length1(self): + self.do_test_exception(nargs="*", metavar=("1",)) + + def test_nargs_zeroormore_metavar_length2(self): + self.do_test_no_exception(nargs="*", metavar=("1", "2")) + + def test_nargs_zeroormore_metavar_length3(self): + self.do_test_exception(nargs="*", metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=+ + + def test_nargs_oneormore_metavar_string(self): + self.do_test_no_exception(nargs="+", metavar="1") + + def test_nargs_oneormore_metavar_length0(self): + self.do_test_exception(nargs="+", metavar=tuple()) + + def test_nargs_oneormore_metavar_length1(self): + self.do_test_exception(nargs="+", metavar=("1",)) + + def test_nargs_oneormore_metavar_length2(self): + self.do_test_no_exception(nargs="+", metavar=("1", "2")) + + def test_nargs_oneormore_metavar_length3(self): + self.do_test_exception(nargs="+", metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=... + + def test_nargs_remainder_metavar_string(self): + self.do_test_no_exception(nargs="...", metavar="1") + + def test_nargs_remainder_metavar_length0(self): + self.do_test_no_exception(nargs="...", metavar=tuple()) + + def test_nargs_remainder_metavar_length1(self): + self.do_test_no_exception(nargs="...", metavar=("1",)) + + def test_nargs_remainder_metavar_length2(self): + self.do_test_no_exception(nargs="...", metavar=("1", "2")) + + def test_nargs_remainder_metavar_length3(self): + self.do_test_no_exception(nargs="...", metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=A... + + def test_nargs_parser_metavar_string(self): + self.do_test_no_exception(nargs="A...", metavar="1") + + def test_nargs_parser_metavar_length0(self): + self.do_test_exception(nargs="A...", metavar=tuple()) + + def test_nargs_parser_metavar_length1(self): + self.do_test_no_exception(nargs="A...", metavar=("1",)) + + def test_nargs_parser_metavar_length2(self): + self.do_test_exception(nargs="A...", metavar=("1", "2")) + + def test_nargs_parser_metavar_length3(self): + self.do_test_exception(nargs="A...", metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=1 + + def test_nargs_1_metavar_string(self): + self.do_test_no_exception(nargs=1, metavar="1") + + def test_nargs_1_metavar_length0(self): + self.do_test_exception(nargs=1, metavar=tuple()) + + def test_nargs_1_metavar_length1(self): + self.do_test_no_exception(nargs=1, metavar=("1",)) + + def test_nargs_1_metavar_length2(self): + self.do_test_exception(nargs=1, metavar=("1", "2")) + + def test_nargs_1_metavar_length3(self): + self.do_test_exception(nargs=1, metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=2 + + def test_nargs_2_metavar_string(self): + self.do_test_no_exception(nargs=2, metavar="1") + + def test_nargs_2_metavar_length0(self): + self.do_test_exception(nargs=2, metavar=tuple()) + + def test_nargs_2_metavar_length1(self): + self.do_test_exception(nargs=2, metavar=("1",)) + + def test_nargs_2_metavar_length2(self): + self.do_test_no_exception(nargs=2, metavar=("1", "2")) + + def test_nargs_2_metavar_length3(self): + self.do_test_exception(nargs=2, metavar=("1", "2", "3")) + + # Unit tests for different values of metavar when nargs=3 + + def test_nargs_3_metavar_string(self): + self.do_test_no_exception(nargs=3, metavar="1") + + def test_nargs_3_metavar_length0(self): + self.do_test_exception(nargs=3, metavar=tuple()) + + def test_nargs_3_metavar_length1(self): + self.do_test_exception(nargs=3, metavar=("1",)) + + def test_nargs_3_metavar_length2(self): + self.do_test_exception(nargs=3, metavar=("1", "2")) + + def test_nargs_3_metavar_length3(self): + self.do_test_no_exception(nargs=3, metavar=("1", "2", "3")) + +# ============================ +# from argparse import * tests +# ============================ + +class TestImportStar(TestCase): + + def test(self): + for name in argparse.__all__: + self.assertTrue(hasattr(argparse, name)) + + def test_all_exports_everything_but_modules(self): + items = [ + name + for name, value in vars(argparse).items() + if not (name.startswith("_") or name == 'ngettext') + if not inspect.ismodule(value) + ] + self.assertEqual(sorted(items), sorted(argparse.__all__)) + + +class TestWrappingMetavar(TestCase): + + def setUp(self): + super().setUp() + self.parser = ErrorRaisingArgumentParser( + 'this_is_spammy_prog_with_a_long_name_sorry_about_the_name' + ) + # this metavar was triggering library assertion errors due to usage + # message formatting incorrectly splitting on the ] chars within + metavar = '' + self.parser.add_argument('--proxy', metavar=metavar) + + def test_help_with_metavar(self): + help_text = self.parser.format_help() + self.assertEqual(help_text, textwrap.dedent('''\ + usage: this_is_spammy_prog_with_a_long_name_sorry_about_the_name + [-h] [--proxy ] + + optional arguments: + -h, --help show this help message and exit + --proxy + ''')) + + +def test_main(): + support.run_unittest(__name__) + # Remove global references to avoid looking like we have refleaks. + RFile.seen = {} + WFile.seen = set() + + + +if __name__ == '__main__': + test_main() diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index 6bdbfe9f0a..7cca83d783 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -341,8 +341,6 @@ def test_iterator_pickle(self): a.fromlist(data2) self.assertEqual(list(it), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_exhausted_iterator(self): a = array.array(self.typecode, self.example) self.assertEqual(list(a), list(self.example)) diff --git a/Lib/test/test_binop.py b/Lib/test/test_binop.py new file mode 100644 index 0000000000..299af09c49 --- /dev/null +++ b/Lib/test/test_binop.py @@ -0,0 +1,440 @@ +"""Tests for binary operators on subtypes of built-in types.""" + +import unittest +from operator import eq, le, ne +from abc import ABCMeta + +def gcd(a, b): + """Greatest common divisor using Euclid's algorithm.""" + while a: + a, b = b%a, a + return b + +def isint(x): + """Test whether an object is an instance of int.""" + return isinstance(x, int) + +def isnum(x): + """Test whether an object is an instance of a built-in numeric type.""" + for T in int, float, complex: + if isinstance(x, T): + return 1 + return 0 + +def isRat(x): + """Test whether an object is an instance of the Rat class.""" + return isinstance(x, Rat) + +class Rat(object): + + """Rational number implemented as a normalized pair of ints.""" + + __slots__ = ['_Rat__num', '_Rat__den'] + + def __init__(self, num=0, den=1): + """Constructor: Rat([num[, den]]). + + The arguments must be ints, and default to (0, 1).""" + if not isint(num): + raise TypeError("Rat numerator must be int (%r)" % num) + if not isint(den): + raise TypeError("Rat denominator must be int (%r)" % den) + # But the zero is always on + if den == 0: + raise ZeroDivisionError("zero denominator") + g = gcd(den, num) + self.__num = int(num//g) + self.__den = int(den//g) + + def _get_num(self): + """Accessor function for read-only 'num' attribute of Rat.""" + return self.__num + num = property(_get_num, None) + + def _get_den(self): + """Accessor function for read-only 'den' attribute of Rat.""" + return self.__den + den = property(_get_den, None) + + def __repr__(self): + """Convert a Rat to a string resembling a Rat constructor call.""" + return "Rat(%d, %d)" % (self.__num, self.__den) + + def __str__(self): + """Convert a Rat to a string resembling a decimal numeric value.""" + return str(float(self)) + + def __float__(self): + """Convert a Rat to a float.""" + return self.__num*1.0/self.__den + + def __int__(self): + """Convert a Rat to an int; self.den must be 1.""" + if self.__den == 1: + try: + return int(self.__num) + except OverflowError: + raise OverflowError("%s too large to convert to int" % + repr(self)) + raise ValueError("can't convert %s to int" % repr(self)) + + def __add__(self, other): + """Add two Rats, or a Rat and a number.""" + if isint(other): + other = Rat(other) + if isRat(other): + return Rat(self.__num*other.__den + other.__num*self.__den, + self.__den*other.__den) + if isnum(other): + return float(self) + other + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + """Subtract two Rats, or a Rat and a number.""" + if isint(other): + other = Rat(other) + if isRat(other): + return Rat(self.__num*other.__den - other.__num*self.__den, + self.__den*other.__den) + if isnum(other): + return float(self) - other + return NotImplemented + + def __rsub__(self, other): + """Subtract two Rats, or a Rat and a number (reversed args).""" + if isint(other): + other = Rat(other) + if isRat(other): + return Rat(other.__num*self.__den - self.__num*other.__den, + self.__den*other.__den) + if isnum(other): + return other - float(self) + return NotImplemented + + def __mul__(self, other): + """Multiply two Rats, or a Rat and a number.""" + if isRat(other): + return Rat(self.__num*other.__num, self.__den*other.__den) + if isint(other): + return Rat(self.__num*other, self.__den) + if isnum(other): + return float(self)*other + return NotImplemented + + __rmul__ = __mul__ + + def __truediv__(self, other): + """Divide two Rats, or a Rat and a number.""" + if isRat(other): + return Rat(self.__num*other.__den, self.__den*other.__num) + if isint(other): + return Rat(self.__num, self.__den*other) + if isnum(other): + return float(self) / other + return NotImplemented + + def __rtruediv__(self, other): + """Divide two Rats, or a Rat and a number (reversed args).""" + if isRat(other): + return Rat(other.__num*self.__den, other.__den*self.__num) + if isint(other): + return Rat(other*self.__den, self.__num) + if isnum(other): + return other / float(self) + return NotImplemented + + def __floordiv__(self, other): + """Divide two Rats, returning the floored result.""" + if isint(other): + other = Rat(other) + elif not isRat(other): + return NotImplemented + x = self/other + return x.__num // x.__den + + def __rfloordiv__(self, other): + """Divide two Rats, returning the floored result (reversed args).""" + x = other/self + return x.__num // x.__den + + def __divmod__(self, other): + """Divide two Rats, returning quotient and remainder.""" + if isint(other): + other = Rat(other) + elif not isRat(other): + return NotImplemented + x = self//other + return (x, self - other * x) + + def __rdivmod__(self, other): + """Divide two Rats, returning quotient and remainder (reversed args).""" + if isint(other): + other = Rat(other) + elif not isRat(other): + return NotImplemented + return divmod(other, self) + + def __mod__(self, other): + """Take one Rat modulo another.""" + return divmod(self, other)[1] + + def __rmod__(self, other): + """Take one Rat modulo another (reversed args).""" + return divmod(other, self)[1] + + def __eq__(self, other): + """Compare two Rats for equality.""" + if isint(other): + return self.__den == 1 and self.__num == other + if isRat(other): + return self.__num == other.__num and self.__den == other.__den + if isnum(other): + return float(self) == other + return NotImplemented + +class RatTestCase(unittest.TestCase): + """Unit tests for Rat class and its support utilities.""" + + def test_gcd(self): + self.assertEqual(gcd(10, 12), 2) + self.assertEqual(gcd(10, 15), 5) + self.assertEqual(gcd(10, 11), 1) + self.assertEqual(gcd(100, 15), 5) + self.assertEqual(gcd(-10, 2), -2) + self.assertEqual(gcd(10, -2), 2) + self.assertEqual(gcd(-10, -2), -2) + for i in range(1, 20): + for j in range(1, 20): + self.assertTrue(gcd(i, j) > 0) + self.assertTrue(gcd(-i, j) < 0) + self.assertTrue(gcd(i, -j) > 0) + self.assertTrue(gcd(-i, -j) < 0) + + def test_constructor(self): + a = Rat(10, 15) + self.assertEqual(a.num, 2) + self.assertEqual(a.den, 3) + a = Rat(10, -15) + self.assertEqual(a.num, -2) + self.assertEqual(a.den, 3) + a = Rat(-10, 15) + self.assertEqual(a.num, -2) + self.assertEqual(a.den, 3) + a = Rat(-10, -15) + self.assertEqual(a.num, 2) + self.assertEqual(a.den, 3) + a = Rat(7) + self.assertEqual(a.num, 7) + self.assertEqual(a.den, 1) + try: + a = Rat(1, 0) + except ZeroDivisionError: + pass + else: + self.fail("Rat(1, 0) didn't raise ZeroDivisionError") + for bad in "0", 0.0, 0j, (), [], {}, None, Rat, unittest: + try: + a = Rat(bad) + except TypeError: + pass + else: + self.fail("Rat(%r) didn't raise TypeError" % bad) + try: + a = Rat(1, bad) + except TypeError: + pass + else: + self.fail("Rat(1, %r) didn't raise TypeError" % bad) + + def test_add(self): + self.assertEqual(Rat(2, 3) + Rat(1, 3), 1) + self.assertEqual(Rat(2, 3) + 1, Rat(5, 3)) + self.assertEqual(1 + Rat(2, 3), Rat(5, 3)) + self.assertEqual(1.0 + Rat(1, 2), 1.5) + self.assertEqual(Rat(1, 2) + 1.0, 1.5) + + def test_sub(self): + self.assertEqual(Rat(7, 2) - Rat(7, 5), Rat(21, 10)) + self.assertEqual(Rat(7, 5) - 1, Rat(2, 5)) + self.assertEqual(1 - Rat(3, 5), Rat(2, 5)) + self.assertEqual(Rat(3, 2) - 1.0, 0.5) + self.assertEqual(1.0 - Rat(1, 2), 0.5) + + def test_mul(self): + self.assertEqual(Rat(2, 3) * Rat(5, 7), Rat(10, 21)) + self.assertEqual(Rat(10, 3) * 3, 10) + self.assertEqual(3 * Rat(10, 3), 10) + self.assertEqual(Rat(10, 5) * 0.5, 1.0) + self.assertEqual(0.5 * Rat(10, 5), 1.0) + + def test_div(self): + self.assertEqual(Rat(10, 3) / Rat(5, 7), Rat(14, 3)) + self.assertEqual(Rat(10, 3) / 3, Rat(10, 9)) + self.assertEqual(2 / Rat(5), Rat(2, 5)) + self.assertEqual(3.0 * Rat(1, 2), 1.5) + self.assertEqual(Rat(1, 2) * 3.0, 1.5) + + def test_floordiv(self): + self.assertEqual(Rat(10) // Rat(4), 2) + self.assertEqual(Rat(10, 3) // Rat(4, 3), 2) + self.assertEqual(Rat(10) // 4, 2) + self.assertEqual(10 // Rat(4), 2) + + def test_eq(self): + self.assertEqual(Rat(10), Rat(20, 2)) + self.assertEqual(Rat(10), 10) + self.assertEqual(10, Rat(10)) + self.assertEqual(Rat(10), 10.0) + self.assertEqual(10.0, Rat(10)) + + def test_true_div(self): + self.assertEqual(Rat(10, 3) / Rat(5, 7), Rat(14, 3)) + self.assertEqual(Rat(10, 3) / 3, Rat(10, 9)) + self.assertEqual(2 / Rat(5), Rat(2, 5)) + self.assertEqual(3.0 * Rat(1, 2), 1.5) + self.assertEqual(Rat(1, 2) * 3.0, 1.5) + self.assertEqual(eval('1/2'), 0.5) + + # XXX Ran out of steam; TO DO: divmod, div, future division + + +class OperationLogger: + """Base class for classes with operation logging.""" + def __init__(self, logger): + self.logger = logger + def log_operation(self, *args): + self.logger(*args) + +def op_sequence(op, *classes): + """Return the sequence of operations that results from applying + the operation `op` to instances of the given classes.""" + log = [] + instances = [] + for c in classes: + instances.append(c(log.append)) + + try: + op(*instances) + except TypeError: + pass + return log + +class A(OperationLogger): + def __eq__(self, other): + self.log_operation('A.__eq__') + return NotImplemented + def __le__(self, other): + self.log_operation('A.__le__') + return NotImplemented + def __ge__(self, other): + self.log_operation('A.__ge__') + return NotImplemented + +class B(OperationLogger, metaclass=ABCMeta): + def __eq__(self, other): + self.log_operation('B.__eq__') + return NotImplemented + def __le__(self, other): + self.log_operation('B.__le__') + return NotImplemented + def __ge__(self, other): + self.log_operation('B.__ge__') + return NotImplemented + +class C(B): + def __eq__(self, other): + self.log_operation('C.__eq__') + return NotImplemented + def __le__(self, other): + self.log_operation('C.__le__') + return NotImplemented + def __ge__(self, other): + self.log_operation('C.__ge__') + return NotImplemented + +class V(OperationLogger): + """Virtual subclass of B""" + def __eq__(self, other): + self.log_operation('V.__eq__') + return NotImplemented + def __le__(self, other): + self.log_operation('V.__le__') + return NotImplemented + def __ge__(self, other): + self.log_operation('V.__ge__') + return NotImplemented +B.register(V) + + +class OperationOrderTests(unittest.TestCase): + def test_comparison_orders(self): + self.assertEqual(op_sequence(eq, A, A), ['A.__eq__', 'A.__eq__']) + self.assertEqual(op_sequence(eq, A, B), ['A.__eq__', 'B.__eq__']) + self.assertEqual(op_sequence(eq, B, A), ['B.__eq__', 'A.__eq__']) + # C is a subclass of B, so C.__eq__ is called first + self.assertEqual(op_sequence(eq, B, C), ['C.__eq__', 'B.__eq__']) + self.assertEqual(op_sequence(eq, C, B), ['C.__eq__', 'B.__eq__']) + + self.assertEqual(op_sequence(le, A, A), ['A.__le__', 'A.__ge__']) + self.assertEqual(op_sequence(le, A, B), ['A.__le__', 'B.__ge__']) + self.assertEqual(op_sequence(le, B, A), ['B.__le__', 'A.__ge__']) + self.assertEqual(op_sequence(le, B, C), ['C.__ge__', 'B.__le__']) + self.assertEqual(op_sequence(le, C, B), ['C.__le__', 'B.__ge__']) + + self.assertTrue(issubclass(V, B)) + self.assertEqual(op_sequence(eq, B, V), ['B.__eq__', 'V.__eq__']) + self.assertEqual(op_sequence(le, B, V), ['B.__le__', 'V.__ge__']) + +class SupEq(object): + """Class that can test equality""" + def __eq__(self, other): + return True + +class S(SupEq): + """Subclass of SupEq that should fail""" + __eq__ = None + +class F(object): + """Independent class that should fall back""" + +class X(object): + """Independent class that should fail""" + __eq__ = None + +class SN(SupEq): + """Subclass of SupEq that can test equality, but not non-equality""" + __ne__ = None + +class XN: + """Independent class that can test equality, but not non-equality""" + def __eq__(self, other): + return True + __ne__ = None + +class FallbackBlockingTests(unittest.TestCase): + """Unit tests for None method blocking""" + + def test_fallback_rmethod_blocking(self): + e, f, s, x = SupEq(), F(), S(), X() + self.assertEqual(e, e) + self.assertEqual(e, f) + self.assertEqual(f, e) + # left operand is checked first + self.assertEqual(e, x) + self.assertRaises(TypeError, eq, x, e) + # S is a subclass, so it's always checked first + self.assertRaises(TypeError, eq, e, s) + self.assertRaises(TypeError, eq, s, e) + + def test_fallback_ne_blocking(self): + e, sn, xn = SupEq(), SN(), XN() + self.assertFalse(e != e) + self.assertRaises(TypeError, ne, e, sn) + self.assertRaises(TypeError, ne, sn, e) + self.assertFalse(e != xn) + self.assertRaises(TypeError, ne, xn, e) + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_decorators.py b/Lib/test/test_decorators.py index d0a2ec9fdd..d6c4cd5666 100644 --- a/Lib/test/test_decorators.py +++ b/Lib/test/test_decorators.py @@ -151,21 +151,18 @@ def double(x): self.assertEqual(counts['double'], 4) def test_errors(self): - # Test syntax restrictions - these are all compile-time errors: - # - for expr in [ "1+2", "x[3]", "(1, 2)" ]: - # Sanity check: is expr is a valid expression by itself? - compile(expr, "testexpr", "exec") - - codestr = "@%s\ndef f(): pass" % expr - self.assertRaises(SyntaxError, compile, codestr, "test", "exec") - # You can't put multiple decorators on a single line: - # - self.assertRaises(SyntaxError, compile, - "@f1 @f2\ndef f(): pass", "test", "exec") + # Test SyntaxErrors: + for stmt in ("x,", "x, y", "x = y", "pass", "import sys"): + compile(stmt, "test", "exec") # Sanity check. + with self.assertRaises(SyntaxError): + compile(f"@{stmt}\ndef f(): pass", "test", "exec") - # Test runtime errors + # Test TypeErrors that used to be SyntaxErrors: + for expr in ("1.+2j", "[1, 2][-1]", "(1, 2)", "True", "...", "None"): + compile(expr, "test", "eval") # Sanity check. + with self.assertRaises(TypeError): + exec(f"@{expr}\ndef f(): pass") def unimp(func): raise NotImplementedError @@ -179,6 +176,18 @@ def unimp(func): code = compile(codestr, "test", "exec") self.assertRaises(exc, eval, code, context) + def test_expressions(self): + for expr in ( + ## original tests + # "(x,)", "(x, y)", "x := y", "(x := y)", "x @y", "(x @ y)", "x[0]", + # "w[x].y.z", "w + x - (y + z)", "x(y)()(z)", "[w, x, y][z]", "x.y", + + ##same without := + "(x,)", "(x, y)", "x @y", "(x @ y)", "x[0]", + "w[x].y.z", "w + x - (y + z)", "x(y)()(z)", "[w, x, y][z]", "x.y", + ): + compile(f"@{expr}\ndef f(): pass", "test", "exec") + def test_double(self): class C(object): @funcattrs(abc=1, xyz="haha") @@ -265,6 +274,45 @@ def bar(): return 42 self.assertEqual(bar(), 42) self.assertEqual(actions, expected_actions) + def test_wrapped_descriptor_inside_classmethod(self): + class BoundWrapper: + def __init__(self, wrapped): + self.__wrapped__ = wrapped + + def __call__(self, *args, **kwargs): + return self.__wrapped__(*args, **kwargs) + + class Wrapper: + def __init__(self, wrapped): + self.__wrapped__ = wrapped + + def __get__(self, instance, owner): + bound_function = self.__wrapped__.__get__(instance, owner) + return BoundWrapper(bound_function) + + def decorator(wrapped): + return Wrapper(wrapped) + + class Class: + @decorator + @classmethod + def inner(cls): + # This should already work. + return 'spam' + + @classmethod + @decorator + def outer(cls): + # Raised TypeError with a message saying that the 'Wrapper' + # object is not callable. + return 'eggs' + + self.assertEqual(Class.inner(), 'spam') + #self.assertEqual(Class.outer(), 'eggs') # TODO RustPython + self.assertEqual(Class().inner(), 'spam') + #self.assertEqual(Class().outer(), 'eggs') # TODO RustPython + + class TestClassDecorators(unittest.TestCase): def test_simple(self): @@ -301,4 +349,4 @@ class C(object): pass self.assertEqual(C.extra, 'second') if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py new file mode 100644 index 0000000000..de64c51e77 --- /dev/null +++ b/Lib/test/test_long.py @@ -0,0 +1,1401 @@ +import unittest +from test import support + +import sys + +import random +import math +import array + +# SHIFT should match the value in longintrepr.h for best testing. +SHIFT = 32 #sys.int_info.bits_per_digit # TODO RustPython int_info not supported +BASE = 2 ** SHIFT +MASK = BASE - 1 +KARATSUBA_CUTOFF = 70 # from longobject.c + +# Max number of base BASE digits to use in test cases. Doubling +# this will more than double the runtime. +MAXDIGITS = 15 + +# build some special values +special = [0, 1, 2, BASE, BASE >> 1, 0x5555555555555555, 0xaaaaaaaaaaaaaaaa] +# some solid strings of one bits +p2 = 4 # 0 and 1 already added +for i in range(2*SHIFT): + special.append(p2 - 1) + p2 = p2 << 1 +del p2 +# add complements & negations +special += [~x for x in special] + [-x for x in special] + +DBL_MAX = 1.7976931348623157E+308 # sys.float_info.max # TODO RustPython +DBL_MAX_EXP = 1024 # sys.float_info.max_exp +DBL_MIN_EXP = -1021 # sys.float_info.min_exp +DBL_MANT_DIG = 53 # sys.float_info.mant_dig +DBL_MIN_OVERFLOW = 2**DBL_MAX_EXP - 2**(DBL_MAX_EXP - DBL_MANT_DIG - 1) + + +# Pure Python version of correctly-rounded integer-to-float conversion. +def int_to_float(n): + """ + Correctly-rounded integer-to-float conversion. + """ + # Constants, depending only on the floating-point format in use. + # We use an extra 2 bits of precision for rounding purposes. + PRECISION = sys.float_info.mant_dig + 2 + SHIFT_MAX = sys.float_info.max_exp - PRECISION + Q_MAX = 1 << PRECISION + ROUND_HALF_TO_EVEN_CORRECTION = [0, -1, -2, 1, 0, -1, 2, 1] + + # Reduce to the case where n is positive. + if n == 0: + return 0.0 + elif n < 0: + return -int_to_float(-n) + + # Convert n to a 'floating-point' number q * 2**shift, where q is an + # integer with 'PRECISION' significant bits. When shifting n to create q, + # the least significant bit of q is treated as 'sticky'. That is, the + # least significant bit of q is set if either the corresponding bit of n + # was already set, or any one of the bits of n lost in the shift was set. + shift = n.bit_length() - PRECISION + q = n << -shift if shift < 0 else (n >> shift) | bool(n & ~(-1 << shift)) + + # Round half to even (actually rounds to the nearest multiple of 4, + # rounding ties to a multiple of 8). + q += ROUND_HALF_TO_EVEN_CORRECTION[q & 7] + + # Detect overflow. + if shift + (q == Q_MAX) > SHIFT_MAX: + raise OverflowError("integer too large to convert to float") + + # Checks: q is exactly representable, and q**2**shift doesn't overflow. + assert q % 4 == 0 and q // 4 <= 2**(sys.float_info.mant_dig) + assert q * 2**shift <= sys.float_info.max + + # Some circularity here, since float(q) is doing an int-to-float + # conversion. But here q is of bounded size, and is exactly representable + # as a float. In a low-level C-like language, this operation would be a + # simple cast (e.g., from unsigned long long to double). + return math.ldexp(float(q), shift) + + +# pure Python version of correctly-rounded true division +def truediv(a, b): + """Correctly-rounded true division for integers.""" + negative = a^b < 0 + a, b = abs(a), abs(b) + + # exceptions: division by zero, overflow + if not b: + raise ZeroDivisionError("division by zero") + if a >= DBL_MIN_OVERFLOW * b: + raise OverflowError("int/int too large to represent as a float") + + # find integer d satisfying 2**(d - 1) <= a/b < 2**d + d = a.bit_length() - b.bit_length() + if d >= 0 and a >= 2**d * b or d < 0 and a * 2**-d >= b: + d += 1 + + # compute 2**-exp * a / b for suitable exp + exp = max(d, DBL_MIN_EXP) - DBL_MANT_DIG + a, b = a << max(-exp, 0), b << max(exp, 0) + q, r = divmod(a, b) + + # round-half-to-even: fractional part is r/b, which is > 0.5 iff + # 2*r > b, and == 0.5 iff 2*r == b. + if 2*r > b or 2*r == b and q % 2 == 1: + q += 1 + + result = math.ldexp(q, exp) + return -result if negative else result + + +class LongTest(unittest.TestCase): + + # Get quasi-random long consisting of ndigits digits (in base BASE). + # quasi == the most-significant digit will not be 0, and the number + # is constructed to contain long strings of 0 and 1 bits. These are + # more likely than random bits to provoke digit-boundary errors. + # The sign of the number is also random. + + def getran(self, ndigits): + self.assertGreater(ndigits, 0) + nbits_hi = ndigits * SHIFT + nbits_lo = nbits_hi - SHIFT + 1 + answer = 0 + nbits = 0 + r = int(random.random() * (SHIFT * 2)) | 1 # force 1 bits to start + while nbits < nbits_lo: + bits = (r >> 1) + 1 + bits = min(bits, nbits_hi - nbits) + self.assertTrue(1 <= bits <= SHIFT) + nbits = nbits + bits + answer = answer << bits + if r & 1: + answer = answer | ((1 << bits) - 1) + r = int(random.random() * (SHIFT * 2)) + self.assertTrue(nbits_lo <= nbits <= nbits_hi) + if random.random() < 0.5: + answer = -answer + return answer + + # Get random long consisting of ndigits random digits (relative to base + # BASE). The sign bit is also random. + + def getran2(ndigits): + answer = 0 + for i in range(ndigits): + answer = (answer << SHIFT) | random.randint(0, MASK) + if random.random() < 0.5: + answer = -answer + return answer + + def check_division(self, x, y): + eq = self.assertEqual + with self.subTest(x=x, y=y): + q, r = divmod(x, y) + q2, r2 = x//y, x%y + pab, pba = x*y, y*x + eq(pab, pba, "multiplication does not commute") + eq(q, q2, "divmod returns different quotient than /") + eq(r, r2, "divmod returns different mod than %") + eq(x, q*y + r, "x != q*y + r after divmod") + if y > 0: + self.assertTrue(0 <= r < y, "bad mod from divmod") + else: + self.assertTrue(y < r <= 0, "bad mod from divmod") + + def test_division(self): + digits = list(range(1, MAXDIGITS+1)) + list(range(KARATSUBA_CUTOFF, + KARATSUBA_CUTOFF + 14)) + digits.append(KARATSUBA_CUTOFF * 3) + for lenx in digits: + x = self.getran(lenx) + for leny in digits: + y = self.getran(leny) or 1 + self.check_division(x, y) + + # specific numbers chosen to exercise corner cases of the + # current long division implementation + + # 30-bit cases involving a quotient digit estimate of BASE+1 + self.check_division(1231948412290879395966702881, + 1147341367131428698) + self.check_division(815427756481275430342312021515587883, + 707270836069027745) + self.check_division(627976073697012820849443363563599041, + 643588798496057020) + self.check_division(1115141373653752303710932756325578065, + 1038556335171453937726882627) + # 30-bit cases that require the post-subtraction correction step + self.check_division(922498905405436751940989320930368494, + 949985870686786135626943396) + self.check_division(768235853328091167204009652174031844, + 1091555541180371554426545266) + + # 15-bit cases involving a quotient digit estimate of BASE+1 + self.check_division(20172188947443, 615611397) + self.check_division(1020908530270155025, 950795710) + self.check_division(128589565723112408, 736393718) + self.check_division(609919780285761575, 18613274546784) + # 15-bit cases that require the post-subtraction correction step + self.check_division(710031681576388032, 26769404391308) + self.check_division(1933622614268221, 30212853348836) + + + + def test_karatsuba(self): + digits = list(range(1, 5)) + list(range(KARATSUBA_CUTOFF, + KARATSUBA_CUTOFF + 10)) + digits.extend([KARATSUBA_CUTOFF * 10, KARATSUBA_CUTOFF * 100]) + + bits = [digit * SHIFT for digit in digits] + + # Test products of long strings of 1 bits -- (2**x-1)*(2**y-1) == + # 2**(x+y) - 2**x - 2**y + 1, so the proper result is easy to check. + for abits in bits: + a = (1 << abits) - 1 + for bbits in bits: + if bbits < abits: + continue + with self.subTest(abits=abits, bbits=bbits): + b = (1 << bbits) - 1 + x = a * b + y = ((1 << (abits + bbits)) - + (1 << abits) - + (1 << bbits) + + 1) + self.assertEqual(x, y) + + def check_bitop_identities_1(self, x): + eq = self.assertEqual + with self.subTest(x=x): + eq(x & 0, 0) + eq(x | 0, x) + eq(x ^ 0, x) + eq(x & -1, x) + eq(x | -1, -1) + eq(x ^ -1, ~x) + eq(x, ~~x) + eq(x & x, x) + eq(x | x, x) + eq(x ^ x, 0) + eq(x & ~x, 0) + eq(x | ~x, -1) + eq(x ^ ~x, -1) + eq(-x, 1 + ~x) + eq(-x, ~(x-1)) + for n in range(2*SHIFT): + p2 = 2 ** n + with self.subTest(x=x, n=n, p2=p2): + eq(x << n >> n, x) + eq(x // p2, x >> n) + eq(x * p2, x << n) + eq(x & -p2, x >> n << n) + eq(x & -p2, x & ~(p2 - 1)) + + def check_bitop_identities_2(self, x, y): + eq = self.assertEqual + with self.subTest(x=x, y=y): + eq(x & y, y & x) + eq(x | y, y | x) + eq(x ^ y, y ^ x) + eq(x ^ y ^ x, y) + eq(x & y, ~(~x | ~y)) + eq(x | y, ~(~x & ~y)) + eq(x ^ y, (x | y) & ~(x & y)) + eq(x ^ y, (x & ~y) | (~x & y)) + eq(x ^ y, (x | y) & (~x | ~y)) + + def check_bitop_identities_3(self, x, y, z): + eq = self.assertEqual + with self.subTest(x=x, y=y, z=z): + eq((x & y) & z, x & (y & z)) + eq((x | y) | z, x | (y | z)) + eq((x ^ y) ^ z, x ^ (y ^ z)) + eq(x & (y | z), (x & y) | (x & z)) + eq(x | (y & z), (x | y) & (x | z)) + + def test_bitop_identities(self): + for x in special: + self.check_bitop_identities_1(x) + digits = range(1, MAXDIGITS+1) + for lenx in digits: + x = self.getran(lenx) + self.check_bitop_identities_1(x) + for leny in digits: + y = self.getran(leny) + self.check_bitop_identities_2(x, y) + self.check_bitop_identities_3(x, y, self.getran((lenx + leny)//2)) + + def slow_format(self, x, base): + digits = [] + sign = 0 + if x < 0: + sign, x = 1, -x + while x: + x, r = divmod(x, base) + digits.append(int(r)) + digits.reverse() + digits = digits or [0] + return '-'[:sign] + \ + {2: '0b', 8: '0o', 10: '', 16: '0x'}[base] + \ + "".join("0123456789abcdef"[i] for i in digits) + + def check_format_1(self, x): + for base, mapper in (2, bin), (8, oct), (10, str), (10, repr), (16, hex): + got = mapper(x) + with self.subTest(x=x, mapper=mapper.__name__): + expected = self.slow_format(x, base) + self.assertEqual(got, expected) + with self.subTest(got=got): + self.assertEqual(int(got, 0), x) + + @unittest.expectedFailure # RustPython + def test_format(self): + for x in special: + self.check_format_1(x) + for i in range(10): + for lenx in range(1, MAXDIGITS+1): + x = self.getran(lenx) + self.check_format_1(x) + + def test_long(self): + # Check conversions from string + LL = [ + ('1' + '0'*20, 10**20), + ('1' + '0'*100, 10**100) + ] + for s, v in LL: + for sign in "", "+", "-": + for prefix in "", " ", "\t", " \t\t ": + ss = prefix + sign + s + vv = v + if sign == "-" and v is not ValueError: + vv = -v + try: + self.assertEqual(int(ss), vv) + except ValueError: + pass + + # trailing L should no longer be accepted... + self.assertRaises(ValueError, int, '123L') + self.assertRaises(ValueError, int, '123l') + self.assertRaises(ValueError, int, '0L') + self.assertRaises(ValueError, int, '-37L') + self.assertRaises(ValueError, int, '0x32L', 16) + self.assertRaises(ValueError, int, '1L', 21) + # ... but it's just a normal digit if base >= 22 + self.assertEqual(int('1L', 22), 43) + + # tests with base 0 + self.assertEqual(int('000', 0), 0) + self.assertEqual(int('0o123', 0), 83) + self.assertEqual(int('0x123', 0), 291) + self.assertEqual(int('0b100', 0), 4) + self.assertEqual(int(' 0O123 ', 0), 83) + self.assertEqual(int(' 0X123 ', 0), 291) + self.assertEqual(int(' 0B100 ', 0), 4) + self.assertEqual(int('0', 0), 0) + self.assertEqual(int('+0', 0), 0) + self.assertEqual(int('-0', 0), 0) + self.assertEqual(int('00', 0), 0) + self.assertRaises(ValueError, int, '08', 0) + #self.assertRaises(ValueError, int, '-012395', 0) # move to individual test case + + # invalid bases + invalid_bases = [-909, + 2**31-1, 2**31, -2**31, -2**31-1, + 2**63-1, 2**63, -2**63, -2**63-1, + 2**100, -2**100, + ] + for base in invalid_bases: + self.assertRaises(ValueError, int, '42', base) + + # Invalid unicode string + # See bpo-34087 + self.assertRaises(ValueError, int, '\u3053\u3093\u306b\u3061\u306f') + + @unittest.expectedFailure # TODO RustPython + def test_long_a(self): + self.assertRaises(ValueError, int, '-012395', 0) + + + @unittest.expectedFailure # TODO RustPython + def test_conversion(self): + + class JustLong: + # test that __long__ no longer used in 3.x + def __long__(self): + return 42 + self.assertRaises(TypeError, int, JustLong()) + + class LongTrunc: + # __long__ should be ignored in 3.x + def __long__(self): + return 42 + def __trunc__(self): + return 1729 + self.assertEqual(int(LongTrunc()), 1729) + + def check_float_conversion(self, n): + # Check that int -> float conversion behaviour matches + # that of the pure Python version above. + try: + actual = float(n) + except OverflowError: + actual = 'overflow' + + try: + expected = int_to_float(n) + except OverflowError: + expected = 'overflow' + + msg = ("Error in conversion of integer {} to float. " + "Got {}, expected {}.".format(n, actual, expected)) + self.assertEqual(actual, expected, msg) + + #@support.requires_IEEE_754 + @unittest.skip # TODO RustPython + def test_float_conversion(self): + + exact_values = [0, 1, 2, + 2**53-3, + 2**53-2, + 2**53-1, + 2**53, + 2**53+2, + 2**54-4, + 2**54-2, + 2**54, + 2**54+4] + for x in exact_values: + self.assertEqual(float(x), x) + self.assertEqual(float(-x), -x) + + # test round-half-even + for x, y in [(1, 0), (2, 2), (3, 4), (4, 4), (5, 4), (6, 6), (7, 8)]: + for p in range(15): + self.assertEqual(int(float(2**p*(2**53+x))), 2**p*(2**53+y)) + + for x, y in [(0, 0), (1, 0), (2, 0), (3, 4), (4, 4), (5, 4), (6, 8), + (7, 8), (8, 8), (9, 8), (10, 8), (11, 12), (12, 12), + (13, 12), (14, 16), (15, 16)]: + for p in range(15): + self.assertEqual(int(float(2**p*(2**54+x))), 2**p*(2**54+y)) + + # behaviour near extremes of floating-point range + int_dbl_max = int(DBL_MAX) + top_power = 2**DBL_MAX_EXP + halfway = (int_dbl_max + top_power)//2 + self.assertEqual(float(int_dbl_max), DBL_MAX) + self.assertEqual(float(int_dbl_max+1), DBL_MAX) + self.assertEqual(float(halfway-1), DBL_MAX) + self.assertRaises(OverflowError, float, halfway) + self.assertEqual(float(1-halfway), -DBL_MAX) + self.assertRaises(OverflowError, float, -halfway) + self.assertRaises(OverflowError, float, top_power-1) + self.assertRaises(OverflowError, float, top_power) + self.assertRaises(OverflowError, float, top_power+1) + self.assertRaises(OverflowError, float, 2*top_power-1) + self.assertRaises(OverflowError, float, 2*top_power) + self.assertRaises(OverflowError, float, top_power*top_power) + + for p in range(100): + x = 2**p * (2**53 + 1) + 1 + y = 2**p * (2**53 + 2) + self.assertEqual(int(float(x)), y) + + x = 2**p * (2**53 + 1) + y = 2**p * 2**53 + self.assertEqual(int(float(x)), y) + + # Compare builtin float conversion with pure Python int_to_float + # function above. + test_values = [ + int_dbl_max-1, int_dbl_max, int_dbl_max+1, + halfway-1, halfway, halfway + 1, + top_power-1, top_power, top_power+1, + 2*top_power-1, 2*top_power, top_power*top_power, + ] + test_values.extend(exact_values) + for p in range(-4, 8): + for x in range(-128, 128): + test_values.append(2**(p+53) + x) + for value in test_values: + self.check_float_conversion(value) + self.check_float_conversion(-value) + + def test_float_overflow(self): + for x in -2.0, -1.0, 0.0, 1.0, 2.0: + self.assertEqual(float(int(x)), x) + + shuge = '12345' * 120 + huge = 1 << 30000 + mhuge = -huge + namespace = {'huge': huge, 'mhuge': mhuge, 'shuge': shuge, 'math': math} + for test in ["float(huge)", "float(mhuge)", + "complex(huge)", "complex(mhuge)", + "complex(huge, 1)", "complex(mhuge, 1)", + "complex(1, huge)", "complex(1, mhuge)", + "1. + huge", "huge + 1.", "1. + mhuge", "mhuge + 1.", + "1. - huge", "huge - 1.", "1. - mhuge", "mhuge - 1.", + "1. * huge", "huge * 1.", "1. * mhuge", "mhuge * 1.", + "1. // huge", "huge // 1.", "1. // mhuge", "mhuge // 1.", + "1. / huge", "huge / 1.", "1. / mhuge", "mhuge / 1.", + "1. ** huge", "huge ** 1.", "1. ** mhuge", "mhuge ** 1.", + "math.sin(huge)", "math.sin(mhuge)", + "math.sqrt(huge)", "math.sqrt(mhuge)", # should do better + # math.floor() of an int returns an int now + ##"math.floor(huge)", "math.floor(mhuge)", + ]: + + self.assertRaises(OverflowError, eval, test, namespace) + + # XXX Perhaps float(shuge) can raise OverflowError on some box? + # The comparison should not. + self.assertNotEqual(float(shuge), int(shuge), + "float(shuge) should not equal int(shuge)") + + @unittest.expectedFailure # TODO RustPython + def test_logs(self): + LOG10E = math.log10(math.e) + + for exp in list(range(10)) + [100, 1000, 10000]: + value = 10 ** exp + log10 = math.log10(value) + self.assertAlmostEqual(log10, exp) + + # log10(value) == exp, so log(value) == log10(value)/log10(e) == + # exp/LOG10E + expected = exp / LOG10E + log = math.log(value) + self.assertAlmostEqual(log, expected) + + for bad in -(1 << 10000), -2, 0: + self.assertRaises(ValueError, math.log, bad) + self.assertRaises(ValueError, math.log10, bad) + + def test_mixed_compares(self): + eq = self.assertEqual + + # We're mostly concerned with that mixing floats and ints does the + # right stuff, even when ints are too large to fit in a float. + # The safest way to check the results is to use an entirely different + # method, which we do here via a skeletal rational class (which + # represents all Python ints and floats exactly). + class Rat: + def __init__(self, value): + if isinstance(value, int): + self.n = value + self.d = 1 + elif isinstance(value, float): + # Convert to exact rational equivalent. + f, e = math.frexp(abs(value)) + assert f == 0 or 0.5 <= f < 1.0 + # |value| = f * 2**e exactly + + # Suck up CHUNK bits at a time; 28 is enough so that we suck + # up all bits in 2 iterations for all known binary double- + # precision formats, and small enough to fit in an int. + CHUNK = 28 + top = 0 + # invariant: |value| = (top + f) * 2**e exactly + while f: + f = math.ldexp(f, CHUNK) + digit = int(f) + assert digit >> CHUNK == 0 + top = (top << CHUNK) | digit + f -= digit + assert 0.0 <= f < 1.0 + e -= CHUNK + + # Now |value| = top * 2**e exactly. + if e >= 0: + n = top << e + d = 1 + else: + n = top + d = 1 << -e + if value < 0: + n = -n + self.n = n + self.d = d + assert float(n) / float(d) == value + else: + raise TypeError("can't deal with %r" % value) + + def _cmp__(self, other): + if not isinstance(other, Rat): + other = Rat(other) + x, y = self.n * other.d, self.d * other.n + return (x > y) - (x < y) + def __eq__(self, other): + return self._cmp__(other) == 0 + def __ge__(self, other): + return self._cmp__(other) >= 0 + def __gt__(self, other): + return self._cmp__(other) > 0 + def __le__(self, other): + return self._cmp__(other) <= 0 + def __lt__(self, other): + return self._cmp__(other) < 0 + + cases = [0, 0.001, 0.99, 1.0, 1.5, 1e20, 1e200] + # 2**48 is an important boundary in the internals. 2**53 is an + # important boundary for IEEE double precision. + for t in 2.0**48, 2.0**50, 2.0**53: + cases.extend([t - 1.0, t - 0.3, t, t + 0.3, t + 1.0, + int(t-1), int(t), int(t+1)]) + cases.extend([0, 1, 2, sys.maxsize, float(sys.maxsize)]) + # 1 << 20000 should exceed all double formats. int(1e200) is to + # check that we get equality with 1e200 above. + t = int(1e200) + cases.extend([0, 1, 2, 1 << 20000, t-1, t, t+1]) + cases.extend([-x for x in cases]) + for x in cases: + Rx = Rat(x) + for y in cases: + Ry = Rat(y) + Rcmp = (Rx > Ry) - (Rx < Ry) + with self.subTest(x=x, y=y, Rcmp=Rcmp): + xycmp = (x > y) - (x < y) + eq(Rcmp, xycmp) + eq(x == y, Rcmp == 0) + eq(x != y, Rcmp != 0) + eq(x < y, Rcmp < 0) + eq(x <= y, Rcmp <= 0) + eq(x > y, Rcmp > 0) + eq(x >= y, Rcmp >= 0) + + @unittest.expectedFailure + def test__format__(self): + self.assertEqual(format(123456789, 'd'), '123456789') + self.assertEqual(format(123456789, 'd'), '123456789') + self.assertEqual(format(123456789, ','), '123,456,789') + self.assertEqual(format(123456789, '_'), '123_456_789') + + # sign and aligning are interdependent + self.assertEqual(format(1, "-"), '1') + self.assertEqual(format(-1, "-"), '-1') + self.assertEqual(format(1, "-3"), ' 1') + self.assertEqual(format(-1, "-3"), ' -1') + self.assertEqual(format(1, "+3"), ' +1') + self.assertEqual(format(-1, "+3"), ' -1') + self.assertEqual(format(1, " 3"), ' 1') + self.assertEqual(format(-1, " 3"), ' -1') + self.assertEqual(format(1, " "), ' 1') + self.assertEqual(format(-1, " "), '-1') + + # hex + self.assertEqual(format(3, "x"), "3") + self.assertEqual(format(3, "X"), "3") + self.assertEqual(format(1234, "x"), "4d2") + self.assertEqual(format(-1234, "x"), "-4d2") + self.assertEqual(format(1234, "8x"), " 4d2") + self.assertEqual(format(-1234, "8x"), " -4d2") + self.assertEqual(format(1234, "x"), "4d2") + self.assertEqual(format(-1234, "x"), "-4d2") + self.assertEqual(format(-3, "x"), "-3") + self.assertEqual(format(-3, "X"), "-3") + self.assertEqual(format(int('be', 16), "x"), "be") + self.assertEqual(format(int('be', 16), "X"), "BE") + self.assertEqual(format(-int('be', 16), "x"), "-be") + self.assertEqual(format(-int('be', 16), "X"), "-BE") + self.assertRaises(ValueError, format, 1234567890, ',x') + self.assertEqual(format(1234567890, '_x'), '4996_02d2') + self.assertEqual(format(1234567890, '_X'), '4996_02D2') + + # octal + self.assertEqual(format(3, "o"), "3") + self.assertEqual(format(-3, "o"), "-3") + self.assertEqual(format(1234, "o"), "2322") + self.assertEqual(format(-1234, "o"), "-2322") + self.assertEqual(format(1234, "-o"), "2322") + self.assertEqual(format(-1234, "-o"), "-2322") + self.assertEqual(format(1234, " o"), " 2322") + self.assertEqual(format(-1234, " o"), "-2322") + self.assertEqual(format(1234, "+o"), "+2322") + self.assertEqual(format(-1234, "+o"), "-2322") + self.assertRaises(ValueError, format, 1234567890, ',o') + self.assertEqual(format(1234567890, '_o'), '111_4540_1322') + + # binary + self.assertEqual(format(3, "b"), "11") + self.assertEqual(format(-3, "b"), "-11") + self.assertEqual(format(1234, "b"), "10011010010") + self.assertEqual(format(-1234, "b"), "-10011010010") + self.assertEqual(format(1234, "-b"), "10011010010") + self.assertEqual(format(-1234, "-b"), "-10011010010") + self.assertEqual(format(1234, " b"), " 10011010010") + self.assertEqual(format(-1234, " b"), "-10011010010") + self.assertEqual(format(1234, "+b"), "+10011010010") + self.assertEqual(format(-1234, "+b"), "-10011010010") + self.assertRaises(ValueError, format, 1234567890, ',b') + self.assertEqual(format(12345, '_b'), '11_0000_0011_1001') + + # make sure these are errors + self.assertRaises(ValueError, format, 3, "1.3") # precision disallowed + self.assertRaises(ValueError, format, 3, "_c") # underscore, + self.assertRaises(ValueError, format, 3, ",c") # comma, and + self.assertRaises(ValueError, format, 3, "+c") # sign not allowed + # with 'c' + + self.assertRaisesRegex(ValueError, 'Cannot specify both', format, 3, '_,') + self.assertRaisesRegex(ValueError, 'Cannot specify both', format, 3, ',_') + self.assertRaisesRegex(ValueError, 'Cannot specify both', format, 3, '_,d') + self.assertRaisesRegex(ValueError, 'Cannot specify both', format, 3, ',_d') + + self.assertRaisesRegex(ValueError, "Cannot specify ',' with 's'", format, 3, ',s') + self.assertRaisesRegex(ValueError, "Cannot specify '_' with 's'", format, 3, '_s') + + # ensure that only int and float type specifiers work + for format_spec in ([chr(x) for x in range(ord('a'), ord('z')+1)] + + [chr(x) for x in range(ord('A'), ord('Z')+1)]): + if not format_spec in 'bcdoxXeEfFgGn%': + self.assertRaises(ValueError, format, 0, format_spec) + self.assertRaises(ValueError, format, 1, format_spec) + self.assertRaises(ValueError, format, -1, format_spec) + self.assertRaises(ValueError, format, 2**100, format_spec) + self.assertRaises(ValueError, format, -(2**100), format_spec) + + # ensure that float type specifiers work; format converts + # the int to a float + for format_spec in 'eEfFgG%': + for value in [0, 1, -1, 100, -100, 1234567890, -1234567890]: + self.assertEqual(format(value, format_spec), + format(float(value), format_spec)) + + def test_nan_inf(self): + self.assertRaises(OverflowError, int, float('inf')) + self.assertRaises(OverflowError, int, float('-inf')) + self.assertRaises(ValueError, int, float('nan')) + + def test_mod_division(self): + with self.assertRaises(ZeroDivisionError): + _ = 1 % 0 + + self.assertEqual(13 % 10, 3) + self.assertEqual(-13 % 10, 7) + self.assertEqual(13 % -10, -7) + self.assertEqual(-13 % -10, -3) + + self.assertEqual(12 % 4, 0) + self.assertEqual(-12 % 4, 0) + self.assertEqual(12 % -4, 0) + self.assertEqual(-12 % -4, 0) + + def test_true_division(self): + huge = 1 << 40000 + mhuge = -huge + self.assertEqual(huge / huge, 1.0) + self.assertEqual(mhuge / mhuge, 1.0) + self.assertEqual(huge / mhuge, -1.0) + self.assertEqual(mhuge / huge, -1.0) + self.assertEqual(1 / huge, 0.0) + self.assertEqual(1 / huge, 0.0) + self.assertEqual(1 / mhuge, 0.0) + self.assertEqual(1 / mhuge, 0.0) + self.assertEqual((666 * huge + (huge >> 1)) / huge, 666.5) + self.assertEqual((666 * mhuge + (mhuge >> 1)) / mhuge, 666.5) + self.assertEqual((666 * huge + (huge >> 1)) / mhuge, -666.5) + self.assertEqual((666 * mhuge + (mhuge >> 1)) / huge, -666.5) + self.assertEqual(huge / (huge << 1), 0.5) + self.assertEqual((1000000 * huge) / huge, 1000000) + + namespace = {'huge': huge, 'mhuge': mhuge} + + for overflow in ["float(huge)", "float(mhuge)", + "huge / 1", "huge / 2", "huge / -1", "huge / -2", + "mhuge / 100", "mhuge / 200"]: + self.assertRaises(OverflowError, eval, overflow, namespace) + + for underflow in ["1 / huge", "2 / huge", "-1 / huge", "-2 / huge", + "100 / mhuge", "200 / mhuge"]: + result = eval(underflow, namespace) + self.assertEqual(result, 0.0, + "expected underflow to 0 from %r" % underflow) + + for zero in ["huge / 0", "mhuge / 0"]: + self.assertRaises(ZeroDivisionError, eval, zero, namespace) + + def test_floordiv(self): + with self.assertRaises(ZeroDivisionError): + _ = 1 // 0 + + self.assertEqual(2 // 3, 0) + self.assertEqual(2 // -3, -1) + self.assertEqual(-2 // 3, -1) + self.assertEqual(-2 // -3, 0) + + self.assertEqual(-11 // -3, 3) + self.assertEqual(-11 // 3, -4) + self.assertEqual(11 // -3, -4) + self.assertEqual(11 // 3, 3) + + self.assertEqual(-12 // -3, 4) + self.assertEqual(-12 // 3, -4) + self.assertEqual(12 // -3, -4) + self.assertEqual(12 // 3, 4) + + def check_truediv(self, a, b, skip_small=True): + """Verify that the result of a/b is correctly rounded, by + comparing it with a pure Python implementation of correctly + rounded division. b should be nonzero.""" + + # skip check for small a and b: in this case, the current + # implementation converts the arguments to float directly and + # then applies a float division. This can give doubly-rounded + # results on x87-using machines (particularly 32-bit Linux). + if skip_small and max(abs(a), abs(b)) < 2**DBL_MANT_DIG: + return + + try: + # use repr so that we can distinguish between -0.0 and 0.0 + expected = repr(truediv(a, b)) + except OverflowError: + expected = 'overflow' + except ZeroDivisionError: + expected = 'zerodivision' + + try: + got = repr(a / b) + except OverflowError: + got = 'overflow' + except ZeroDivisionError: + got = 'zerodivision' + + self.assertEqual(expected, got, "Incorrectly rounded division {}/{}: " + "expected {}, got {}".format(a, b, expected, got)) + + #@support.requires_IEEE_754 + @unittest.skip + def test_correctly_rounded_true_division(self): + # more stringent tests than those above, checking that the + # result of true division of ints is always correctly rounded. + # This test should probably be considered CPython-specific. + + # Exercise all the code paths not involving Gb-sized ints. + # ... divisions involving zero + self.check_truediv(123, 0) + self.check_truediv(-456, 0) + self.check_truediv(0, 3) + self.check_truediv(0, -3) + self.check_truediv(0, 0) + # ... overflow or underflow by large margin + self.check_truediv(671 * 12345 * 2**DBL_MAX_EXP, 12345) + self.check_truediv(12345, 345678 * 2**(DBL_MANT_DIG - DBL_MIN_EXP)) + # ... a much larger or smaller than b + self.check_truediv(12345*2**100, 98765) + self.check_truediv(12345*2**30, 98765*7**81) + # ... a / b near a boundary: one of 1, 2**DBL_MANT_DIG, 2**DBL_MIN_EXP, + # 2**DBL_MAX_EXP, 2**(DBL_MIN_EXP-DBL_MANT_DIG) + bases = (0, DBL_MANT_DIG, DBL_MIN_EXP, + DBL_MAX_EXP, DBL_MIN_EXP - DBL_MANT_DIG) + for base in bases: + for exp in range(base - 15, base + 15): + self.check_truediv(75312*2**max(exp, 0), 69187*2**max(-exp, 0)) + self.check_truediv(69187*2**max(exp, 0), 75312*2**max(-exp, 0)) + + # overflow corner case + for m in [1, 2, 7, 17, 12345, 7**100, + -1, -2, -5, -23, -67891, -41**50]: + for n in range(-10, 10): + self.check_truediv(m*DBL_MIN_OVERFLOW + n, m) + self.check_truediv(m*DBL_MIN_OVERFLOW + n, -m) + + # check detection of inexactness in shifting stage + for n in range(250): + # (2**DBL_MANT_DIG+1)/(2**DBL_MANT_DIG) lies halfway + # between two representable floats, and would usually be + # rounded down under round-half-to-even. The tiniest of + # additions to the numerator should cause it to be rounded + # up instead. + self.check_truediv((2**DBL_MANT_DIG + 1)*12345*2**200 + 2**n, + 2**DBL_MANT_DIG*12345) + + # 1/2731 is one of the smallest division cases that's subject + # to double rounding on IEEE 754 machines working internally with + # 64-bit precision. On such machines, the next check would fail, + # were it not explicitly skipped in check_truediv. + self.check_truediv(1, 2731) + + # a particularly bad case for the old algorithm: gives an + # error of close to 3.5 ulps. + self.check_truediv(295147931372582273023, 295147932265116303360) + for i in range(1000): + self.check_truediv(10**(i+1), 10**i) + self.check_truediv(10**i, 10**(i+1)) + + # test round-half-to-even behaviour, normal result + for m in [1, 2, 4, 7, 8, 16, 17, 32, 12345, 7**100, + -1, -2, -5, -23, -67891, -41**50]: + for n in range(-10, 10): + self.check_truediv(2**DBL_MANT_DIG*m + n, m) + + # test round-half-to-even, subnormal result + for n in range(-20, 20): + self.check_truediv(n, 2**1076) + + # largeish random divisions: a/b where |a| <= |b| <= + # 2*|a|; |ans| is between 0.5 and 1.0, so error should + # always be bounded by 2**-54 with equality possible only + # if the least significant bit of q=ans*2**53 is zero. + for M in [10**10, 10**100, 10**1000]: + for i in range(1000): + a = random.randrange(1, M) + b = random.randrange(a, 2*a+1) + self.check_truediv(a, b) + self.check_truediv(-a, b) + self.check_truediv(a, -b) + self.check_truediv(-a, -b) + + # and some (genuinely) random tests + for _ in range(10000): + a_bits = random.randrange(1000) + b_bits = random.randrange(1, 1000) + x = random.randrange(2**a_bits) + y = random.randrange(1, 2**b_bits) + self.check_truediv(x, y) + self.check_truediv(x, -y) + self.check_truediv(-x, y) + self.check_truediv(-x, -y) + + def test_negative_shift_count(self): + with self.assertRaises(ValueError): + 42 << -3 + with self.assertRaises(ValueError): + 42 << -(1 << 1000) + with self.assertRaises(ValueError): + 42 >> -3 + with self.assertRaises(ValueError): + 42 >> -(1 << 1000) + + @unittest.expectedFailure # TODO RustPython + def test_lshift_of_zero(self): + self.assertEqual(0 << 0, 0) + self.assertEqual(0 << 10, 0) + with self.assertRaises(ValueError): + 0 << -1 + self.assertEqual(0 << (1 << 1000), 0) + with self.assertRaises(ValueError): + 0 << -(1 << 1000) + + @support.cpython_only + def test_huge_lshift_of_zero(self): + # Shouldn't try to allocate memory for a huge shift. See issue #27870. + # Other implementations may have a different boundary for overflow, + # or not raise at all. + self.assertEqual(0 << sys.maxsize, 0) + self.assertEqual(0 << (sys.maxsize + 1), 0) + + @support.cpython_only + @support.bigmemtest(sys.maxsize + 1000, memuse=2/15 * 2, dry_run=False) + def test_huge_lshift(self, size): + self.assertEqual(1 << (sys.maxsize + 1000), 1 << 1000 << sys.maxsize) + + @unittest.expectedFailure # TODO RustPytohn + def test_huge_rshift(self): + self.assertEqual(42 >> (1 << 1000), 0) + self.assertEqual((-42) >> (1 << 1000), -1) + + @support.cpython_only + @support.bigmemtest(sys.maxsize + 500, memuse=2/15, dry_run=False) + def test_huge_rshift_of_huge(self, size): + huge = ((1 << 500) + 11) << sys.maxsize + self.assertEqual(huge >> (sys.maxsize + 1), (1 << 499) + 5) + self.assertEqual(huge >> (sys.maxsize + 1000), 0) + + @support.cpython_only + def test_small_ints_in_huge_calculation(self): + a = 2 ** 100 + b = -a + 1 + c = a + 1 + self.assertIs(a + b, 1) + self.assertIs(c - a, 1) + + @unittest.expectedFailure + def test_small_ints(self): + for i in range(-5, 257): + self.assertIs(i, i + 0) + self.assertIs(i, i * 1) + self.assertIs(i, i - 0) + self.assertIs(i, i // 1) + self.assertIs(i, i & -1) + self.assertIs(i, i | 0) + self.assertIs(i, i ^ 0) + self.assertIs(i, ~~i) + self.assertIs(i, i**1) + self.assertIs(i, int(str(i))) + self.assertIs(i, i<<2>>2, str(i)) + # corner cases + i = 1 << 70 + self.assertIs(i - i, 0) + self.assertIs(0 * i, 0) + + def test_bit_length(self): + tiny = 1e-10 + for x in range(-65000, 65000): + k = x.bit_length() + # Check equivalence with Python version + self.assertEqual(k, len(bin(x).lstrip('-0b'))) + # Behaviour as specified in the docs + if x != 0: + self.assertTrue(2**(k-1) <= abs(x) < 2**k) + else: + self.assertEqual(k, 0) + # Alternative definition: x.bit_length() == 1 + floor(log_2(x)) + if x != 0: + # When x is an exact power of 2, numeric errors can + # cause floor(log(x)/log(2)) to be one too small; for + # small x this can be fixed by adding a small quantity + # to the quotient before taking the floor. + self.assertEqual(k, 1 + math.floor( + math.log(abs(x))/math.log(2) + tiny)) + + self.assertEqual((0).bit_length(), 0) + self.assertEqual((1).bit_length(), 1) + self.assertEqual((-1).bit_length(), 1) + self.assertEqual((2).bit_length(), 2) + self.assertEqual((-2).bit_length(), 2) + for i in [2, 3, 15, 16, 17, 31, 32, 33, 63, 64, 234]: + a = 2**i + self.assertEqual((a-1).bit_length(), i) + self.assertEqual((1-a).bit_length(), i) + self.assertEqual((a).bit_length(), i+1) + self.assertEqual((-a).bit_length(), i+1) + self.assertEqual((a+1).bit_length(), i+1) + self.assertEqual((-a-1).bit_length(), i+1) + + def test_bit_count(self): + for a in range(-1000, 1000): + self.assertEqual(a.bit_count(), bin(a).count("1")) + + for exp in [10, 17, 63, 64, 65, 1009, 70234, 1234567]: + a = 2**exp + self.assertEqual(a.bit_count(), 1) + self.assertEqual((a - 1).bit_count(), exp) + self.assertEqual((a ^ 63).bit_count(), 7) + self.assertEqual(((a - 1) ^ 510).bit_count(), exp - 8) + + @unittest.expectedFailure + def test_round(self): + # check round-half-even algorithm. For round to nearest ten; + # rounding map is invariant under adding multiples of 20 + test_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, + 6:10, 7:10, 8:10, 9:10, 10:10, 11:10, 12:10, 13:10, 14:10, + 15:20, 16:20, 17:20, 18:20, 19:20} + for offset in range(-520, 520, 20): + for k, v in test_dict.items(): + got = round(k+offset, -1) + expected = v+offset + self.assertEqual(got, expected) + self.assertIs(type(got), int) + + # larger second argument + self.assertEqual(round(-150, -2), -200) + self.assertEqual(round(-149, -2), -100) + self.assertEqual(round(-51, -2), -100) + self.assertEqual(round(-50, -2), 0) + self.assertEqual(round(-49, -2), 0) + self.assertEqual(round(-1, -2), 0) + self.assertEqual(round(0, -2), 0) + self.assertEqual(round(1, -2), 0) + self.assertEqual(round(49, -2), 0) + self.assertEqual(round(50, -2), 0) + self.assertEqual(round(51, -2), 100) + self.assertEqual(round(149, -2), 100) + self.assertEqual(round(150, -2), 200) + self.assertEqual(round(250, -2), 200) + self.assertEqual(round(251, -2), 300) + self.assertEqual(round(172500, -3), 172000) + self.assertEqual(round(173500, -3), 174000) + self.assertEqual(round(31415926535, -1), 31415926540) + self.assertEqual(round(31415926535, -2), 31415926500) + self.assertEqual(round(31415926535, -3), 31415927000) + self.assertEqual(round(31415926535, -4), 31415930000) + self.assertEqual(round(31415926535, -5), 31415900000) + self.assertEqual(round(31415926535, -6), 31416000000) + self.assertEqual(round(31415926535, -7), 31420000000) + self.assertEqual(round(31415926535, -8), 31400000000) + self.assertEqual(round(31415926535, -9), 31000000000) + self.assertEqual(round(31415926535, -10), 30000000000) + self.assertEqual(round(31415926535, -11), 0) + self.assertEqual(round(31415926535, -12), 0) + self.assertEqual(round(31415926535, -999), 0) + + # should get correct results even for huge inputs + for k in range(10, 100): + got = round(10**k + 324678, -3) + expect = 10**k + 325000 + self.assertEqual(got, expect) + self.assertIs(type(got), int) + + # nonnegative second argument: round(x, n) should just return x + for n in range(5): + for i in range(100): + x = random.randrange(-10000, 10000) + got = round(x, n) + self.assertEqual(got, x) + self.assertIs(type(got), int) + for huge_n in 2**31-1, 2**31, 2**63-1, 2**63, 2**100, 10**100: + self.assertEqual(round(8979323, huge_n), 8979323) + + # omitted second argument + for i in range(100): + x = random.randrange(-10000, 10000) + got = round(x) + self.assertEqual(got, x) + self.assertIs(type(got), int) + + # bad second argument + bad_exponents = ('brian', 2.0, 0j) + for e in bad_exponents: + self.assertRaises(TypeError, round, 3, e) + + @unittest.expectedFailure # TODO RustPython + def test_to_bytes(self): + def check(tests, byteorder, signed=False): + for test, expected in tests.items(): + try: + self.assertEqual( + test.to_bytes(len(expected), byteorder, signed=signed), + expected) + except Exception as err: + raise AssertionError( + "failed to convert {0} with byteorder={1} and signed={2}" + .format(test, byteorder, signed)) from err + + # Convert integers to signed big-endian byte arrays. + tests1 = { + 0: b'\x00', + 1: b'\x01', + -1: b'\xff', + -127: b'\x81', + -128: b'\x80', + -129: b'\xff\x7f', + 127: b'\x7f', + 129: b'\x00\x81', + -255: b'\xff\x01', + -256: b'\xff\x00', + 255: b'\x00\xff', + 256: b'\x01\x00', + 32767: b'\x7f\xff', + -32768: b'\xff\x80\x00', + 65535: b'\x00\xff\xff', + -65536: b'\xff\x00\x00', + -8388608: b'\x80\x00\x00' + } + check(tests1, 'big', signed=True) + + # Convert integers to signed little-endian byte arrays. + tests2 = { + 0: b'\x00', + 1: b'\x01', + -1: b'\xff', + -127: b'\x81', + -128: b'\x80', + -129: b'\x7f\xff', + 127: b'\x7f', + 129: b'\x81\x00', + -255: b'\x01\xff', + -256: b'\x00\xff', + 255: b'\xff\x00', + 256: b'\x00\x01', + 32767: b'\xff\x7f', + -32768: b'\x00\x80', + 65535: b'\xff\xff\x00', + -65536: b'\x00\x00\xff', + -8388608: b'\x00\x00\x80' + } + check(tests2, 'little', signed=True) + + # Convert integers to unsigned big-endian byte arrays. + tests3 = { + 0: b'\x00', + 1: b'\x01', + 127: b'\x7f', + 128: b'\x80', + 255: b'\xff', + 256: b'\x01\x00', + 32767: b'\x7f\xff', + 32768: b'\x80\x00', + 65535: b'\xff\xff', + 65536: b'\x01\x00\x00' + } + check(tests3, 'big', signed=False) + + # Convert integers to unsigned little-endian byte arrays. + tests4 = { + 0: b'\x00', + 1: b'\x01', + 127: b'\x7f', + 128: b'\x80', + 255: b'\xff', + 256: b'\x00\x01', + 32767: b'\xff\x7f', + 32768: b'\x00\x80', + 65535: b'\xff\xff', + 65536: b'\x00\x00\x01' + } + check(tests4, 'little', signed=False) + + self.assertRaises(OverflowError, (256).to_bytes, 1, 'big', signed=False) + self.assertRaises(OverflowError, (256).to_bytes, 1, 'big', signed=True) + self.assertRaises(OverflowError, (256).to_bytes, 1, 'little', signed=False) + self.assertRaises(OverflowError, (256).to_bytes, 1, 'little', signed=True) + self.assertRaises(OverflowError, (-1).to_bytes, 2, 'big', signed=False) + self.assertRaises(OverflowError, (-1).to_bytes, 2, 'little', signed=False) + self.assertEqual((0).to_bytes(0, 'big'), b'') + self.assertEqual((1).to_bytes(5, 'big'), b'\x00\x00\x00\x00\x01') + self.assertEqual((0).to_bytes(5, 'big'), b'\x00\x00\x00\x00\x00') + self.assertEqual((-1).to_bytes(5, 'big', signed=True), + b'\xff\xff\xff\xff\xff') + self.assertRaises(OverflowError, (1).to_bytes, 0, 'big') + + @unittest.expectedFailure + def test_from_bytes(self): + def check(tests, byteorder, signed=False): + for test, expected in tests.items(): + try: + self.assertEqual( + int.from_bytes(test, byteorder, signed=signed), + expected) + except Exception as err: + raise AssertionError( + "failed to convert {0} with byteorder={1!r} and signed={2}" + .format(test, byteorder, signed)) from err + + # Convert signed big-endian byte arrays to integers. + tests1 = { + b'': 0, + b'\x00': 0, + b'\x00\x00': 0, + b'\x01': 1, + b'\x00\x01': 1, + b'\xff': -1, + b'\xff\xff': -1, + b'\x81': -127, + b'\x80': -128, + b'\xff\x7f': -129, + b'\x7f': 127, + b'\x00\x81': 129, + b'\xff\x01': -255, + b'\xff\x00': -256, + b'\x00\xff': 255, + b'\x01\x00': 256, + b'\x7f\xff': 32767, + b'\x80\x00': -32768, + b'\x00\xff\xff': 65535, + b'\xff\x00\x00': -65536, + b'\x80\x00\x00': -8388608 + } + check(tests1, 'big', signed=True) + + # Convert signed little-endian byte arrays to integers. + tests2 = { + b'': 0, + b'\x00': 0, + b'\x00\x00': 0, + b'\x01': 1, + b'\x00\x01': 256, + b'\xff': -1, + b'\xff\xff': -1, + b'\x81': -127, + b'\x80': -128, + b'\x7f\xff': -129, + b'\x7f': 127, + b'\x81\x00': 129, + b'\x01\xff': -255, + b'\x00\xff': -256, + b'\xff\x00': 255, + b'\x00\x01': 256, + b'\xff\x7f': 32767, + b'\x00\x80': -32768, + b'\xff\xff\x00': 65535, + b'\x00\x00\xff': -65536, + b'\x00\x00\x80': -8388608 + } + check(tests2, 'little', signed=True) + + # Convert unsigned big-endian byte arrays to integers. + tests3 = { + b'': 0, + b'\x00': 0, + b'\x01': 1, + b'\x7f': 127, + b'\x80': 128, + b'\xff': 255, + b'\x01\x00': 256, + b'\x7f\xff': 32767, + b'\x80\x00': 32768, + b'\xff\xff': 65535, + b'\x01\x00\x00': 65536, + } + check(tests3, 'big', signed=False) + + # Convert integers to unsigned little-endian byte arrays. + tests4 = { + b'': 0, + b'\x00': 0, + b'\x01': 1, + b'\x7f': 127, + b'\x80': 128, + b'\xff': 255, + b'\x00\x01': 256, + b'\xff\x7f': 32767, + b'\x00\x80': 32768, + b'\xff\xff': 65535, + b'\x00\x00\x01': 65536, + } + check(tests4, 'little', signed=False) + + class myint(int): + pass + + self.assertIs(type(myint.from_bytes(b'\x00', 'big')), myint) + self.assertEqual(myint.from_bytes(b'\x01', 'big'), 1) + self.assertIs( + type(myint.from_bytes(b'\x00', 'big', signed=False)), myint) + self.assertEqual(myint.from_bytes(b'\x01', 'big', signed=False), 1) + self.assertIs(type(myint.from_bytes(b'\x00', 'little')), myint) + self.assertEqual(myint.from_bytes(b'\x01', 'little'), 1) + self.assertIs(type(myint.from_bytes( + b'\x00', 'little', signed=False)), myint) + self.assertEqual(myint.from_bytes(b'\x01', 'little', signed=False), 1) + self.assertEqual( + int.from_bytes([255, 0, 0], 'big', signed=True), -65536) + self.assertEqual( + int.from_bytes((255, 0, 0), 'big', signed=True), -65536) + self.assertEqual(int.from_bytes( + bytearray(b'\xff\x00\x00'), 'big', signed=True), -65536) + self.assertEqual(int.from_bytes( + bytearray(b'\xff\x00\x00'), 'big', signed=True), -65536) + self.assertEqual(int.from_bytes( + array.array('B', b'\xff\x00\x00'), 'big', signed=True), -65536) + self.assertEqual(int.from_bytes( + memoryview(b'\xff\x00\x00'), 'big', signed=True), -65536) + self.assertRaises(ValueError, int.from_bytes, [256], 'big') + self.assertRaises(ValueError, int.from_bytes, [0], 'big\x00') + self.assertRaises(ValueError, int.from_bytes, [0], 'little\x00') + self.assertRaises(TypeError, int.from_bytes, "", 'big') + self.assertRaises(TypeError, int.from_bytes, "\x00", 'big') + self.assertRaises(TypeError, int.from_bytes, 0, 'big') + self.assertRaises(TypeError, int.from_bytes, 0, 'big', True) + self.assertRaises(TypeError, myint.from_bytes, "", 'big') + self.assertRaises(TypeError, myint.from_bytes, "\x00", 'big') + self.assertRaises(TypeError, myint.from_bytes, 0, 'big') + self.assertRaises(TypeError, int.from_bytes, 0, 'big', True) + + class myint2(int): + def __new__(cls, value): + return int.__new__(cls, value + 1) + + i = myint2.from_bytes(b'\x01', 'big') + self.assertIs(type(i), myint2) + self.assertEqual(i, 2) + + class myint3(int): + def __init__(self, value): + self.foo = 'bar' + + i = myint3.from_bytes(b'\x01', 'big') + self.assertIs(type(i), myint3) + self.assertEqual(i, 1) + self.assertEqual(getattr(i, 'foo', 'none'), 'bar') + + def test_access_to_nonexistent_digit_0(self): + # http://bugs.python.org/issue14630: A bug in _PyLong_Copy meant that + # ob_digit[0] was being incorrectly accessed for instances of a + # subclass of int, with value 0. + class Integer(int): + def __new__(cls, value=0): + self = int.__new__(cls, value) + self.foo = 'foo' + return self + + integers = [Integer(0) for i in range(1000)] + for n in map(int, integers): + self.assertEqual(n, 0) + + def test_shift_bool(self): + # Issue #21422: ensure that bool << int and bool >> int return int + for value in (True, False): + for shift in (0, 2): + self.assertEqual(type(value << shift), int) + self.assertEqual(type(value >> shift), int) + + def test_as_integer_ratio(self): + class myint(int): + pass + tests = [10, 0, -10, 1, sys.maxsize + 1, True, False, myint(42)] + for value in tests: + numerator, denominator = value.as_integer_ratio() + self.assertEqual((numerator, denominator), (int(value), 1)) + self.assertEqual(type(numerator), int) + self.assertEqual(type(denominator), int) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py new file mode 100644 index 0000000000..1abdac73b5 --- /dev/null +++ b/Lib/test/test_math.py @@ -0,0 +1,2190 @@ + +# Python test set -- math module +# XXXX Should not do tests around zero only + +from test.support import run_unittest, verbose#, requires_IEEE_754 # TODO: RUSTPYTHON, commented due to import error +from test import support +import unittest +import itertools +import decimal +import math +import os +import platform +import random +import struct +import sys + + +eps = 1E-05 +NAN = float('nan') +INF = float('inf') +NINF = float('-inf') + +# TODO RustPython: float_info is so far not supported -> hard code for the moment +# FLOAT_MAX = sys.float_info.max +# FLOAT_MIN = sys.float_info.min +FLOAT_MAX = 1.7976931348623157e+308 +FLOAT_MIN = 2.2250738585072014e-308 + +# detect evidence of double-rounding: fsum is not always correctly +# rounded on machines that suffer from double rounding. +x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer +HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4) + +# locate file with test values +if __name__ == '__main__': + file = sys.argv[0] +else: + file = __file__ +test_dir = os.path.dirname(file) or os.curdir +math_testcases = os.path.join(test_dir, 'math_testcases.txt') +test_file = os.path.join(test_dir, 'cmath_testcases.txt') + + +def to_ulps(x): + """Convert a non-NaN float x to an integer, in such a way that + adjacent floats are converted to adjacent integers. Then + abs(ulps(x) - ulps(y)) gives the difference in ulps between two + floats. + The results from this function will only make sense on platforms + where native doubles are represented in IEEE 754 binary64 format. + Note: 0.0 and -0.0 are converted to 0 and -1, respectively. + """ + n = struct.unpack('= 0} product_{0 < j <= n >> i; j odd} j +# +# The outer product above is an infinite product, but once i >= n.bit_length, +# (n >> i) < 1 and the corresponding term of the product is empty. So only the +# finitely many terms for 0 <= i < n.bit_length() contribute anything. +# +# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner +# product in the formula above starts at 1 for i == n.bit_length(); for each i +# < n.bit_length() we get the inner product for i from that for i + 1 by +# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms, +# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2). + +def count_set_bits(n): + """Number of '1' bits in binary expansion of a nonnnegative integer.""" + return 1 + count_set_bits(n & n - 1) if n else 0 + +def partial_product(start, stop): + """Product of integers in range(start, stop, 2), computed recursively. + start and stop should both be odd, with start <= stop. + """ + numfactors = (stop - start) >> 1 + if not numfactors: + return 1 + elif numfactors == 1: + return start + else: + mid = (start + numfactors) | 1 + return partial_product(start, mid) * partial_product(mid, stop) + +def py_factorial(n): + """Factorial of nonnegative integer n, via "Binary Split Factorial Formula" + described at http://www.luschny.de/math/factorial/binarysplitfact.html + """ + inner = outer = 1 + for i in reversed(range(n.bit_length())): + inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1) + outer *= inner + return outer << (n - count_set_bits(n)) + +def ulp_abs_check(expected, got, ulp_tol, abs_tol): + """Given finite floats `expected` and `got`, check that they're + approximately equal to within the given number of ulps or the + given absolute tolerance, whichever is bigger. + Returns None on success and an error message on failure. + """ + ulp_error = abs(to_ulps(expected) - to_ulps(got)) + abs_error = abs(expected - got) + + # Succeed if either abs_error <= abs_tol or ulp_error <= ulp_tol. + if abs_error <= abs_tol or ulp_error <= ulp_tol: + return None + else: + fmt = ("error = {:.3g} ({:d} ulps); " + "permitted error = {:.3g} or {:d} ulps") + return fmt.format(abs_error, ulp_error, abs_tol, ulp_tol) + +def parse_mtestfile(fname): + """Parse a file with test values + -- starts a comment + blank lines, or lines containing only a comment, are ignored + other lines are expected to have the form + id fn arg -> expected [flag]* + """ + with open(fname) as fp: + for line in fp: + # strip comments, and skip blank lines + if '--' in line: + line = line[:line.index('--')] + if not line.strip(): + continue + + lhs, rhs = line.split('->') + id, fn, arg = lhs.split() + rhs_pieces = rhs.split() + exp = rhs_pieces[0] + flags = rhs_pieces[1:] + + yield (id, fn, float(arg), float(exp), flags) + + +def parse_testfile(fname): + """Parse a file with test values + Empty lines or lines starting with -- are ignored + yields id, fn, arg_real, arg_imag, exp_real, exp_imag + """ + with open(fname) as fp: + for line in fp: + # skip comment lines and blank lines + if line.startswith('--') or not line.strip(): + continue + + lhs, rhs = line.split('->') + id, fn, arg_real, arg_imag = lhs.split() + rhs_pieces = rhs.split() + exp_real, exp_imag = rhs_pieces[0], rhs_pieces[1] + flags = rhs_pieces[2:] + + yield (id, fn, + float(arg_real), float(arg_imag), + float(exp_real), float(exp_imag), + flags) + + +def result_check(expected, got, ulp_tol=5, abs_tol=0.0): + # Common logic of MathTests.(ftest, test_testcases, test_mtestcases) + """Compare arguments expected and got, as floats, if either + is a float, using a tolerance expressed in multiples of + ulp(expected) or absolutely (if given and greater). + As a convenience, when neither argument is a float, and for + non-finite floats, exact equality is demanded. Also, nan==nan + as far as this function is concerned. + Returns None on success and an error message on failure. + """ + + # Check exactly equal (applies also to strings representing exceptions) + if got == expected: + return None + + failure = "not equal" + + # Turn mixed float and int comparison (e.g. floor()) to all-float + if isinstance(expected, float) and isinstance(got, int): + got = float(got) + elif isinstance(got, float) and isinstance(expected, int): + expected = float(expected) + + if isinstance(expected, float) and isinstance(got, float): + if math.isnan(expected) and math.isnan(got): + # Pass, since both nan + failure = None + elif math.isinf(expected) or math.isinf(got): + # We already know they're not equal, drop through to failure + pass + else: + # Both are finite floats (now). Are they close enough? + failure = ulp_abs_check(expected, got, ulp_tol, abs_tol) + + # arguments are not equal, and if numeric, are too far apart + if failure is not None: + fail_fmt = "expected {!r}, got {!r}" + fail_msg = fail_fmt.format(expected, got) + fail_msg += ' ({})'.format(failure) + return fail_msg + else: + return None + +class FloatLike: + def __init__(self, value): + self.value = value + + def __float__(self): + return self.value + +class IntSubclass(int): + pass + +# Class providing an __index__ method. +class MyIndexable(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + +class MathTests(unittest.TestCase): + + def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0): + """Compare arguments expected and got, as floats, if either + is a float, using a tolerance expressed in multiples of + ulp(expected) or absolutely, whichever is greater. + As a convenience, when neither argument is a float, and for + non-finite floats, exact equality is demanded. Also, nan==nan + in this function. + """ + failure = result_check(expected, got, ulp_tol, abs_tol) + if failure is not None: + self.fail("{}: {}".format(name, failure)) + + def testConstants(self): + # Ref: Abramowitz & Stegun (Dover, 1965) + self.ftest('pi', math.pi, 3.141592653589793238462643) + self.ftest('e', math.e, 2.718281828459045235360287) + self.assertEqual(math.tau, 2*math.pi) + + @unittest.skip('TODO: RustPython') + def testAcos(self): + self.assertRaises(TypeError, math.acos) + self.ftest('acos(-1)', math.acos(-1), math.pi) + self.ftest('acos(0)', math.acos(0), math.pi/2) + self.ftest('acos(1)', math.acos(1), 0) + self.assertRaises(ValueError, math.acos, INF) + self.assertRaises(ValueError, math.acos, NINF) + self.assertRaises(ValueError, math.acos, 1 + eps) + self.assertRaises(ValueError, math.acos, -1 - eps) + self.assertTrue(math.isnan(math.acos(NAN))) + + @unittest.skip('TODO: RustPython') + def testAcosh(self): + self.assertRaises(TypeError, math.acosh) + self.ftest('acosh(1)', math.acosh(1), 0) + self.ftest('acosh(2)', math.acosh(2), 1.3169578969248168) + self.assertRaises(ValueError, math.acosh, 0) + self.assertRaises(ValueError, math.acosh, -1) + self.assertEqual(math.acosh(INF), INF) + self.assertRaises(ValueError, math.acosh, NINF) + self.assertTrue(math.isnan(math.acosh(NAN))) + + @unittest.skip('TODO: RustPython') + def testAsin(self): + self.assertRaises(TypeError, math.asin) + self.ftest('asin(-1)', math.asin(-1), -math.pi/2) + self.ftest('asin(0)', math.asin(0), 0) + self.ftest('asin(1)', math.asin(1), math.pi/2) + self.assertRaises(ValueError, math.asin, INF) + self.assertRaises(ValueError, math.asin, NINF) + self.assertRaises(ValueError, math.asin, 1 + eps) + self.assertRaises(ValueError, math.asin, -1 - eps) + self.assertTrue(math.isnan(math.asin(NAN))) + + def testAsinh(self): + self.assertRaises(TypeError, math.asinh) + self.ftest('asinh(0)', math.asinh(0), 0) + self.ftest('asinh(1)', math.asinh(1), 0.88137358701954305) + self.ftest('asinh(-1)', math.asinh(-1), -0.88137358701954305) + self.assertEqual(math.asinh(INF), INF) + self.assertEqual(math.asinh(NINF), NINF) + self.assertTrue(math.isnan(math.asinh(NAN))) + + def testAtan(self): + self.assertRaises(TypeError, math.atan) + self.ftest('atan(-1)', math.atan(-1), -math.pi/4) + self.ftest('atan(0)', math.atan(0), 0) + self.ftest('atan(1)', math.atan(1), math.pi/4) + self.ftest('atan(inf)', math.atan(INF), math.pi/2) + self.ftest('atan(-inf)', math.atan(NINF), -math.pi/2) + self.assertTrue(math.isnan(math.atan(NAN))) + + @unittest.skip('TODO: RustPython') + def testAtanh(self): + self.assertRaises(TypeError, math.atan) + self.ftest('atanh(0)', math.atanh(0), 0) + self.ftest('atanh(0.5)', math.atanh(0.5), 0.54930614433405489) + self.ftest('atanh(-0.5)', math.atanh(-0.5), -0.54930614433405489) + self.assertRaises(ValueError, math.atanh, 1) + self.assertRaises(ValueError, math.atanh, -1) + self.assertRaises(ValueError, math.atanh, INF) + self.assertRaises(ValueError, math.atanh, NINF) + self.assertTrue(math.isnan(math.atanh(NAN))) + + def testAtan2(self): + self.assertRaises(TypeError, math.atan2) + self.ftest('atan2(-1, 0)', math.atan2(-1, 0), -math.pi/2) + self.ftest('atan2(-1, 1)', math.atan2(-1, 1), -math.pi/4) + self.ftest('atan2(0, 1)', math.atan2(0, 1), 0) + self.ftest('atan2(1, 1)', math.atan2(1, 1), math.pi/4) + self.ftest('atan2(1, 0)', math.atan2(1, 0), math.pi/2) + + # math.atan2(0, x) + self.ftest('atan2(0., -inf)', math.atan2(0., NINF), math.pi) + self.ftest('atan2(0., -2.3)', math.atan2(0., -2.3), math.pi) + self.ftest('atan2(0., -0.)', math.atan2(0., -0.), math.pi) + self.assertEqual(math.atan2(0., 0.), 0.) + self.assertEqual(math.atan2(0., 2.3), 0.) + self.assertEqual(math.atan2(0., INF), 0.) + self.assertTrue(math.isnan(math.atan2(0., NAN))) + # math.atan2(-0, x) + self.ftest('atan2(-0., -inf)', math.atan2(-0., NINF), -math.pi) + self.ftest('atan2(-0., -2.3)', math.atan2(-0., -2.3), -math.pi) + self.ftest('atan2(-0., -0.)', math.atan2(-0., -0.), -math.pi) + self.assertEqual(math.atan2(-0., 0.), -0.) + self.assertEqual(math.atan2(-0., 2.3), -0.) + self.assertEqual(math.atan2(-0., INF), -0.) + self.assertTrue(math.isnan(math.atan2(-0., NAN))) + # math.atan2(INF, x) + self.ftest('atan2(inf, -inf)', math.atan2(INF, NINF), math.pi*3/4) + self.ftest('atan2(inf, -2.3)', math.atan2(INF, -2.3), math.pi/2) + self.ftest('atan2(inf, -0.)', math.atan2(INF, -0.0), math.pi/2) + self.ftest('atan2(inf, 0.)', math.atan2(INF, 0.0), math.pi/2) + self.ftest('atan2(inf, 2.3)', math.atan2(INF, 2.3), math.pi/2) + self.ftest('atan2(inf, inf)', math.atan2(INF, INF), math.pi/4) + self.assertTrue(math.isnan(math.atan2(INF, NAN))) + # math.atan2(NINF, x) + self.ftest('atan2(-inf, -inf)', math.atan2(NINF, NINF), -math.pi*3/4) + self.ftest('atan2(-inf, -2.3)', math.atan2(NINF, -2.3), -math.pi/2) + self.ftest('atan2(-inf, -0.)', math.atan2(NINF, -0.0), -math.pi/2) + self.ftest('atan2(-inf, 0.)', math.atan2(NINF, 0.0), -math.pi/2) + self.ftest('atan2(-inf, 2.3)', math.atan2(NINF, 2.3), -math.pi/2) + self.ftest('atan2(-inf, inf)', math.atan2(NINF, INF), -math.pi/4) + self.assertTrue(math.isnan(math.atan2(NINF, NAN))) + # math.atan2(+finite, x) + self.ftest('atan2(2.3, -inf)', math.atan2(2.3, NINF), math.pi) + self.ftest('atan2(2.3, -0.)', math.atan2(2.3, -0.), math.pi/2) + self.ftest('atan2(2.3, 0.)', math.atan2(2.3, 0.), math.pi/2) + self.assertEqual(math.atan2(2.3, INF), 0.) + self.assertTrue(math.isnan(math.atan2(2.3, NAN))) + # math.atan2(-finite, x) + self.ftest('atan2(-2.3, -inf)', math.atan2(-2.3, NINF), -math.pi) + self.ftest('atan2(-2.3, -0.)', math.atan2(-2.3, -0.), -math.pi/2) + self.ftest('atan2(-2.3, 0.)', math.atan2(-2.3, 0.), -math.pi/2) + self.assertEqual(math.atan2(-2.3, INF), -0.) + self.assertTrue(math.isnan(math.atan2(-2.3, NAN))) + # math.atan2(NAN, x) + self.assertTrue(math.isnan(math.atan2(NAN, NINF))) + self.assertTrue(math.isnan(math.atan2(NAN, -2.3))) + self.assertTrue(math.isnan(math.atan2(NAN, -0.))) + self.assertTrue(math.isnan(math.atan2(NAN, 0.))) + self.assertTrue(math.isnan(math.atan2(NAN, 2.3))) + self.assertTrue(math.isnan(math.atan2(NAN, INF))) + self.assertTrue(math.isnan(math.atan2(NAN, NAN))) + + @unittest.skip('TODO: RustPython') + def testCeil(self): + self.assertRaises(TypeError, math.ceil) + self.assertEqual(int, type(math.ceil(0.5))) + self.assertEqual(math.ceil(0.5), 1) + self.assertEqual(math.ceil(1.0), 1) + self.assertEqual(math.ceil(1.5), 2) + self.assertEqual(math.ceil(-0.5), 0) + self.assertEqual(math.ceil(-1.0), -1) + self.assertEqual(math.ceil(-1.5), -1) + self.assertEqual(math.ceil(0.0), 0) + self.assertEqual(math.ceil(-0.0), 0) + #self.assertEqual(math.ceil(INF), INF) + #self.assertEqual(math.ceil(NINF), NINF) + #self.assertTrue(math.isnan(math.ceil(NAN))) + + class TestCeil: + def __ceil__(self): + return 42 + class FloatCeil(float): + def __ceil__(self): + return 42 + class TestNoCeil: + pass + self.assertEqual(math.ceil(TestCeil()), 42) + self.assertEqual(math.ceil(FloatCeil()), 42) + self.assertEqual(math.ceil(FloatLike(42.5)), 43) + self.assertRaises(TypeError, math.ceil, TestNoCeil()) + + t = TestNoCeil() + t.__ceil__ = lambda *args: args + self.assertRaises(TypeError, math.ceil, t) + self.assertRaises(TypeError, math.ceil, t, 0) + + # TODO Rustpython + # @requires_IEEE_754 + def testCopysign(self): + self.assertEqual(math.copysign(1, 42), 1.0) + self.assertEqual(math.copysign(0., 42), 0.0) + self.assertEqual(math.copysign(1., -42), -1.0) + self.assertEqual(math.copysign(3, 0.), 3.0) + self.assertEqual(math.copysign(4., -0.), -4.0) + + self.assertRaises(TypeError, math.copysign) + # copysign should let us distinguish signs of zeros + self.assertEqual(math.copysign(1., 0.), 1.) + self.assertEqual(math.copysign(1., -0.), -1.) + self.assertEqual(math.copysign(INF, 0.), INF) + self.assertEqual(math.copysign(INF, -0.), NINF) + self.assertEqual(math.copysign(NINF, 0.), INF) + self.assertEqual(math.copysign(NINF, -0.), NINF) + # and of infinities + self.assertEqual(math.copysign(1., INF), 1.) + self.assertEqual(math.copysign(1., NINF), -1.) + self.assertEqual(math.copysign(INF, INF), INF) + self.assertEqual(math.copysign(INF, NINF), NINF) + self.assertEqual(math.copysign(NINF, INF), INF) + self.assertEqual(math.copysign(NINF, NINF), NINF) + self.assertTrue(math.isnan(math.copysign(NAN, 1.))) + self.assertTrue(math.isnan(math.copysign(NAN, INF))) + self.assertTrue(math.isnan(math.copysign(NAN, NINF))) + self.assertTrue(math.isnan(math.copysign(NAN, NAN))) + # copysign(INF, NAN) may be INF or it may be NINF, since + # we don't know whether the sign bit of NAN is set on any + # given platform. + self.assertTrue(math.isinf(math.copysign(INF, NAN))) + # similarly, copysign(2., NAN) could be 2. or -2. + self.assertEqual(abs(math.copysign(2., NAN)), 2.) + + @unittest.skip('TODO: RustPython') + def testCos(self): + self.assertRaises(TypeError, math.cos) + self.ftest('cos(-pi/2)', math.cos(-math.pi/2), 0, abs_tol=math.ulp(1)) + self.ftest('cos(0)', math.cos(0), 1) + self.ftest('cos(pi/2)', math.cos(math.pi/2), 0, abs_tol=math.ulp(1)) + self.ftest('cos(pi)', math.cos(math.pi), -1) + try: + self.assertTrue(math.isnan(math.cos(INF))) + self.assertTrue(math.isnan(math.cos(NINF))) + except ValueError: + self.assertRaises(ValueError, math.cos, INF) + self.assertRaises(ValueError, math.cos, NINF) + self.assertTrue(math.isnan(math.cos(NAN))) + + @unittest.skipIf(sys.platform == 'win32' and platform.machine() in ('ARM', 'ARM64'), + "Windows UCRT is off by 2 ULP this test requires accuracy within 1 ULP") + def testCosh(self): + self.assertRaises(TypeError, math.cosh) + self.ftest('cosh(0)', math.cosh(0), 1) + self.ftest('cosh(2)-2*cosh(1)**2', math.cosh(2)-2*math.cosh(1)**2, -1) # Thanks to Lambert + self.assertEqual(math.cosh(INF), INF) + self.assertEqual(math.cosh(NINF), INF) + self.assertTrue(math.isnan(math.cosh(NAN))) + + def testDegrees(self): + self.assertRaises(TypeError, math.degrees) + self.ftest('degrees(pi)', math.degrees(math.pi), 180.0) + self.ftest('degrees(pi/2)', math.degrees(math.pi/2), 90.0) + self.ftest('degrees(-pi/4)', math.degrees(-math.pi/4), -45.0) + self.ftest('degrees(0)', math.degrees(0), 0) + + @unittest.skip('TODO RustPython') + def testExp(self): + self.assertRaises(TypeError, math.exp) + self.ftest('exp(-1)', math.exp(-1), 1/math.e) + self.ftest('exp(0)', math.exp(0), 1) + self.ftest('exp(1)', math.exp(1), math.e) + self.assertEqual(math.exp(INF), INF) + self.assertEqual(math.exp(NINF), 0.) + self.assertTrue(math.isnan(math.exp(NAN))) + self.assertRaises(OverflowError, math.exp, 1000000) + + def testFabs(self): + self.assertRaises(TypeError, math.fabs) + self.ftest('fabs(-1)', math.fabs(-1), 1) + self.ftest('fabs(0)', math.fabs(0), 0) + self.ftest('fabs(1)', math.fabs(1), 1) + + def testFactorial(self): + self.assertEqual(math.factorial(0), 1) + total = 1 + for i in range(1, 1000): + total *= i + self.assertEqual(math.factorial(i), total) + self.assertEqual(math.factorial(i), py_factorial(i)) + self.assertRaises(ValueError, math.factorial, -1) + self.assertRaises(ValueError, math.factorial, -10**100) + + @unittest.skip('TODO: RustPython') + def testFactorialNonIntegers(self): + with self.assertWarns(DeprecationWarning): + self.assertEqual(math.factorial(5.0), 120) + with self.assertWarns(DeprecationWarning): + self.assertRaises(ValueError, math.factorial, 5.2) + with self.assertWarns(DeprecationWarning): + self.assertRaises(ValueError, math.factorial, -1.0) + with self.assertWarns(DeprecationWarning): + self.assertRaises(ValueError, math.factorial, -1e100) + self.assertRaises(TypeError, math.factorial, decimal.Decimal('5')) + self.assertRaises(TypeError, math.factorial, decimal.Decimal('5.2')) + self.assertRaises(TypeError, math.factorial, "5") + + # Other implementations may place different upper bounds. + @support.cpython_only + def testFactorialHugeInputs(self): + # Currently raises OverflowError for inputs that are too large + # to fit into a C long. + self.assertRaises(OverflowError, math.factorial, 10**100) + with self.assertWarns(DeprecationWarning): + self.assertRaises(OverflowError, math.factorial, 1e100) + + @unittest.skip('TODO RustPython') + def testFloor(self): + self.assertRaises(TypeError, math.floor) + self.assertEqual(int, type(math.floor(0.5))) + self.assertEqual(math.floor(0.5), 0) + self.assertEqual(math.floor(1.0), 1) + self.assertEqual(math.floor(1.5), 1) + self.assertEqual(math.floor(-0.5), -1) + self.assertEqual(math.floor(-1.0), -1) + self.assertEqual(math.floor(-1.5), -2) + #self.assertEqual(math.ceil(INF), INF) + #self.assertEqual(math.ceil(NINF), NINF) + #self.assertTrue(math.isnan(math.floor(NAN))) + + class TestFloor: + def __floor__(self): + return 42 + class FloatFloor(float): + def __floor__(self): + return 42 + class TestNoFloor: + pass + self.assertEqual(math.floor(TestFloor()), 42) + self.assertEqual(math.floor(FloatFloor()), 42) + self.assertEqual(math.floor(FloatLike(41.9)), 41) + self.assertRaises(TypeError, math.floor, TestNoFloor()) + + t = TestNoFloor() + t.__floor__ = lambda *args: args + self.assertRaises(TypeError, math.floor, t) + self.assertRaises(TypeError, math.floor, t, 0) + + def testFmod(self): + self.assertRaises(TypeError, math.fmod) + self.ftest('fmod(10, 1)', math.fmod(10, 1), 0.0) + self.ftest('fmod(10, 0.5)', math.fmod(10, 0.5), 0.0) + self.ftest('fmod(10, 1.5)', math.fmod(10, 1.5), 1.0) + self.ftest('fmod(-10, 1)', math.fmod(-10, 1), -0.0) + self.ftest('fmod(-10, 0.5)', math.fmod(-10, 0.5), -0.0) + self.ftest('fmod(-10, 1.5)', math.fmod(-10, 1.5), -1.0) + self.assertTrue(math.isnan(math.fmod(NAN, 1.))) + self.assertTrue(math.isnan(math.fmod(1., NAN))) + self.assertTrue(math.isnan(math.fmod(NAN, NAN))) + self.assertRaises(ValueError, math.fmod, 1., 0.) + self.assertRaises(ValueError, math.fmod, INF, 1.) + self.assertRaises(ValueError, math.fmod, NINF, 1.) + self.assertRaises(ValueError, math.fmod, INF, 0.) + self.assertEqual(math.fmod(3.0, INF), 3.0) + self.assertEqual(math.fmod(-3.0, INF), -3.0) + self.assertEqual(math.fmod(3.0, NINF), 3.0) + self.assertEqual(math.fmod(-3.0, NINF), -3.0) + self.assertEqual(math.fmod(0.0, 3.0), 0.0) + self.assertEqual(math.fmod(0.0, NINF), 0.0) + + def testFrexp(self): + self.assertRaises(TypeError, math.frexp) + + def testfrexp(name, result, expected): + (mant, exp), (emant, eexp) = result, expected + if abs(mant-emant) > eps or exp != eexp: + self.fail('%s returned %r, expected %r'%\ + (name, result, expected)) + + testfrexp('frexp(-1)', math.frexp(-1), (-0.5, 1)) + testfrexp('frexp(0)', math.frexp(0), (0, 0)) + testfrexp('frexp(1)', math.frexp(1), (0.5, 1)) + testfrexp('frexp(2)', math.frexp(2), (0.5, 2)) + + self.assertEqual(math.frexp(INF)[0], INF) + self.assertEqual(math.frexp(NINF)[0], NINF) + self.assertTrue(math.isnan(math.frexp(NAN)[0])) + + + # TODO Rustpython + # @requires_IEEE_754 + # @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + # "fsum is not exact on machines with double rounding") + # def testFsum(self): + # # math.fsum relies on exact rounding for correct operation. + # # There's a known problem with IA32 floating-point that causes + # # inexact rounding in some situations, and will cause the + # # math.fsum tests below to fail; see issue #2937. On non IEEE + # # 754 platforms, and on IEEE 754 platforms that exhibit the + # # problem described in issue #2937, we simply skip the whole + # # test. + + # # Python version of math.fsum, for comparison. Uses a + # # different algorithm based on frexp, ldexp and integer + # # arithmetic. + # from sys import float_info + # mant_dig = float_info.mant_dig + # etiny = float_info.min_exp - mant_dig + + # def msum(iterable): + # """Full precision summation. Compute sum(iterable) without any + # intermediate accumulation of error. Based on the 'lsum' function + # at http://code.activestate.com/recipes/393090/ + # """ + # tmant, texp = 0, 0 + # for x in iterable: + # mant, exp = math.frexp(x) + # mant, exp = int(math.ldexp(mant, mant_dig)), exp - mant_dig + # if texp > exp: + # tmant <<= texp-exp + # texp = exp + # else: + # mant <<= exp-texp + # tmant += mant + # # Round tmant * 2**texp to a float. The original recipe + # # used float(str(tmant)) * 2.0**texp for this, but that's + # # a little unsafe because str -> float conversion can't be + # # relied upon to do correct rounding on all platforms. + # tail = max(len(bin(abs(tmant)))-2 - mant_dig, etiny - texp) + # if tail > 0: + # h = 1 << (tail-1) + # tmant = tmant // (2*h) + bool(tmant & h and tmant & 3*h-1) + # texp += tail + # return math.ldexp(tmant, texp) + + # test_values = [ + # ([], 0.0), + # ([0.0], 0.0), + # ([1e100, 1.0, -1e100, 1e-100, 1e50, -1.0, -1e50], 1e-100), + # ([2.0**53, -0.5, -2.0**-54], 2.0**53-1.0), + # ([2.0**53, 1.0, 2.0**-100], 2.0**53+2.0), + # ([2.0**53+10.0, 1.0, 2.0**-100], 2.0**53+12.0), + # ([2.0**53-4.0, 0.5, 2.0**-54], 2.0**53-3.0), + # ([1./n for n in range(1, 1001)], + # float.fromhex('0x1.df11f45f4e61ap+2')), + # ([(-1.)**n/n for n in range(1, 1001)], + # float.fromhex('-0x1.62a2af1bd3624p-1')), + # ([1e16, 1., 1e-16], 10000000000000002.0), + # ([1e16-2., 1.-2.**-53, -(1e16-2.), -(1.-2.**-53)], 0.0), + # # exercise code for resizing partials array + # ([2.**n - 2.**(n+50) + 2.**(n+52) for n in range(-1074, 972, 2)] + + # [-2.**1022], + # float.fromhex('0x1.5555555555555p+970')), + # ] + + # # Telescoping sum, with exact differences (due to Sterbenz) + # terms = [1.7**i for i in range(1001)] + # test_values.append(( + # [terms[i+1] - terms[i] for i in range(1000)] + [-terms[1000]], + # -terms[0] + # )) + + # for i, (vals, expected) in enumerate(test_values): + # try: + # actual = math.fsum(vals) + # except OverflowError: + # self.fail("test %d failed: got OverflowError, expected %r " + # "for math.fsum(%.100r)" % (i, expected, vals)) + # except ValueError: + # self.fail("test %d failed: got ValueError, expected %r " + # "for math.fsum(%.100r)" % (i, expected, vals)) + # self.assertEqual(actual, expected) + + # from random import random, gauss, shuffle + # for j in range(1000): + # vals = [7, 1e100, -7, -1e100, -9e-20, 8e-20] * 10 + # s = 0 + # for i in range(200): + # v = gauss(0, random()) ** 7 - s + # s += v + # vals.append(v) + # shuffle(vals) + + # s = msum(vals) + # self.assertEqual(msum(vals), math.fsum(vals)) + + + # Python 3.9 + def testGcd(self): + gcd = math.gcd + self.assertEqual(gcd(0, 0), 0) + self.assertEqual(gcd(1, 0), 1) + self.assertEqual(gcd(-1, 0), 1) + self.assertEqual(gcd(0, 1), 1) + self.assertEqual(gcd(0, -1), 1) + self.assertEqual(gcd(7, 1), 1) + self.assertEqual(gcd(7, -1), 1) + self.assertEqual(gcd(-23, 15), 1) + self.assertEqual(gcd(120, 84), 12) + self.assertEqual(gcd(84, -120), 12) + self.assertEqual(gcd(1216342683557601535506311712, + 436522681849110124616458784), 32) + + x = 434610456570399902378880679233098819019853229470286994367836600566 + y = 1064502245825115327754847244914921553977 + for c in (652560, + 576559230871654959816130551884856912003141446781646602790216406874): + a = x * c + b = y * c + self.assertEqual(gcd(a, b), c) + self.assertEqual(gcd(b, a), c) + self.assertEqual(gcd(-a, b), c) + self.assertEqual(gcd(b, -a), c) + self.assertEqual(gcd(a, -b), c) + self.assertEqual(gcd(-b, a), c) + self.assertEqual(gcd(-a, -b), c) + self.assertEqual(gcd(-b, -a), c) + + self.assertEqual(gcd(), 0) + self.assertEqual(gcd(120), 120) + self.assertEqual(gcd(-120), 120) + self.assertEqual(gcd(120, 84, 102), 6) + self.assertEqual(gcd(120, 1, 84), 1) + + self.assertRaises(TypeError, gcd, 120.0) + self.assertRaises(TypeError, gcd, 120.0, 84) + self.assertRaises(TypeError, gcd, 120, 84.0) + self.assertRaises(TypeError, gcd, 120, 1, 84.0) + #self.assertEqual(gcd(MyIndexable(120), MyIndexable(84)), 12) # TODO RustPython + + @unittest.skip('TODO: RustPython float support') + def testHypot(self): + from decimal import Decimal + from fractions import Fraction + + hypot = math.hypot + + # Test different numbers of arguments (from zero to five) + # against a straightforward pure python implementation + args = math.e, math.pi, math.sqrt(2.0), math.gamma(3.5), math.sin(2.1) + for i in range(len(args)+1): + self.assertAlmostEqual( + hypot(*args[:i]), + math.sqrt(sum(s**2 for s in args[:i])) + ) + + # Test allowable types (those with __float__) + self.assertEqual(hypot(12.0, 5.0), 13.0) + self.assertEqual(hypot(12, 5), 13) + self.assertEqual(hypot(Decimal(12), Decimal(5)), 13) + self.assertEqual(hypot(Fraction(12, 32), Fraction(5, 32)), Fraction(13, 32)) + self.assertEqual(hypot(bool(1), bool(0), bool(1), bool(1)), math.sqrt(3)) + + # Test corner cases + self.assertEqual(hypot(0.0, 0.0), 0.0) # Max input is zero + self.assertEqual(hypot(-10.5), 10.5) # Negative input + self.assertEqual(hypot(), 0.0) # Negative input + self.assertEqual(1.0, + math.copysign(1.0, hypot(-0.0)) # Convert negative zero to positive zero + ) + self.assertEqual( # Handling of moving max to the end + hypot(1.5, 1.5, 0.5), + hypot(1.5, 0.5, 1.5), + ) + + # Test handling of bad arguments + with self.assertRaises(TypeError): # Reject keyword args + hypot(x=1) + with self.assertRaises(TypeError): # Reject values without __float__ + hypot(1.1, 'string', 2.2) + int_too_big_for_float = 10 ** (sys.float_info.max_10_exp + 5) + with self.assertRaises((ValueError, OverflowError)): + hypot(1, int_too_big_for_float) + + # Any infinity gives positive infinity. + self.assertEqual(hypot(INF), INF) + self.assertEqual(hypot(0, INF), INF) + self.assertEqual(hypot(10, INF), INF) + self.assertEqual(hypot(-10, INF), INF) + self.assertEqual(hypot(NAN, INF), INF) + self.assertEqual(hypot(INF, NAN), INF) + self.assertEqual(hypot(NINF, NAN), INF) + self.assertEqual(hypot(NAN, NINF), INF) + self.assertEqual(hypot(-INF, INF), INF) + self.assertEqual(hypot(-INF, -INF), INF) + self.assertEqual(hypot(10, -INF), INF) + + # If no infinity, any NaN gives a NaN. + self.assertTrue(math.isnan(hypot(NAN))) + self.assertTrue(math.isnan(hypot(0, NAN))) + self.assertTrue(math.isnan(hypot(NAN, 10))) + self.assertTrue(math.isnan(hypot(10, NAN))) + self.assertTrue(math.isnan(hypot(NAN, NAN))) + self.assertTrue(math.isnan(hypot(NAN))) + + # Verify scaling for extremely large values + fourthmax = FLOAT_MAX / 4.0 + for n in range(32): + self.assertEqual(hypot(*([fourthmax]*n)), fourthmax * math.sqrt(n)) + + # Verify scaling for extremely small values + for exp in range(32): + scale = FLOAT_MIN / 2.0 ** exp + self.assertEqual(math.hypot(4*scale, 3*scale), 5*scale) + + @unittest.skip('TODO: RustPython') + def testDist(self): + from decimal import Decimal as D + from fractions import Fraction as F + + dist = math.dist + sqrt = math.sqrt + + # Simple exact cases + self.assertEqual(dist((1.0, 2.0, 3.0), (4.0, 2.0, -1.0)), 5.0) + self.assertEqual(dist((1, 2, 3), (4, 2, -1)), 5.0) + + # Test different numbers of arguments (from zero to nine) + # against a straightforward pure python implementation + for i in range(9): + for j in range(5): + p = tuple(random.uniform(-5, 5) for k in range(i)) + q = tuple(random.uniform(-5, 5) for k in range(i)) + self.assertAlmostEqual( + dist(p, q), + sqrt(sum((px - qx) ** 2.0 for px, qx in zip(p, q))) + ) + + # Test non-tuple inputs + self.assertEqual(dist([1.0, 2.0, 3.0], [4.0, 2.0, -1.0]), 5.0) + self.assertEqual(dist(iter([1.0, 2.0, 3.0]), iter([4.0, 2.0, -1.0])), 5.0) + + # Test allowable types (those with __float__) + self.assertEqual(dist((14.0, 1.0), (2.0, -4.0)), 13.0) + self.assertEqual(dist((14, 1), (2, -4)), 13) + self.assertEqual(dist((D(14), D(1)), (D(2), D(-4))), D(13)) + self.assertEqual(dist((F(14, 32), F(1, 32)), (F(2, 32), F(-4, 32))), + F(13, 32)) + self.assertEqual(dist((True, True, False, True, False), + (True, False, True, True, False)), + sqrt(2.0)) + + # Test corner cases + self.assertEqual(dist((13.25, 12.5, -3.25), + (13.25, 12.5, -3.25)), + 0.0) # Distance with self is zero + self.assertEqual(dist((), ()), 0.0) # Zero-dimensional case + self.assertEqual(1.0, # Convert negative zero to positive zero + math.copysign(1.0, dist((-0.0,), (0.0,))) + ) + self.assertEqual(1.0, # Convert negative zero to positive zero + math.copysign(1.0, dist((0.0,), (-0.0,))) + ) + self.assertEqual( # Handling of moving max to the end + dist((1.5, 1.5, 0.5), (0, 0, 0)), + dist((1.5, 0.5, 1.5), (0, 0, 0)) + ) + + # Verify tuple subclasses are allowed + class T(tuple): + pass + self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0) + + # Test handling of bad arguments + with self.assertRaises(TypeError): # Reject keyword args + dist(p=(1, 2, 3), q=(4, 5, 6)) + with self.assertRaises(TypeError): # Too few args + dist((1, 2, 3)) + with self.assertRaises(TypeError): # Too many args + dist((1, 2, 3), (4, 5, 6), (7, 8, 9)) + with self.assertRaises(TypeError): # Scalars not allowed + dist(1, 2) + with self.assertRaises(TypeError): # Reject values without __float__ + dist((1.1, 'string', 2.2), (1, 2, 3)) + with self.assertRaises(ValueError): # Check dimension agree + dist((1, 2, 3, 4), (5, 6, 7)) + with self.assertRaises(ValueError): # Check dimension agree + dist((1, 2, 3), (4, 5, 6, 7)) + with self.assertRaises(TypeError): # Rejects invalid types + dist("abc", "xyz") + int_too_big_for_float = 10 ** (sys.float_info.max_10_exp + 5) + with self.assertRaises((ValueError, OverflowError)): + dist((1, int_too_big_for_float), (2, 3)) + with self.assertRaises((ValueError, OverflowError)): + dist((2, 3), (1, int_too_big_for_float)) + + # Verify that the one dimensional case is equivalent to abs() + for i in range(20): + p, q = random.random(), random.random() + self.assertEqual(dist((p,), (q,)), abs(p - q)) + + # Test special values + values = [NINF, -10.5, -0.0, 0.0, 10.5, INF, NAN] + for p in itertools.product(values, repeat=3): + for q in itertools.product(values, repeat=3): + diffs = [px - qx for px, qx in zip(p, q)] + if any(map(math.isinf, diffs)): + # Any infinite difference gives positive infinity. + self.assertEqual(dist(p, q), INF) + elif any(map(math.isnan, diffs)): + # If no infinity, any NaN gives a NaN. + self.assertTrue(math.isnan(dist(p, q))) + + # Verify scaling for extremely large values + fourthmax = FLOAT_MAX / 4.0 + for n in range(32): + p = (fourthmax,) * n + q = (0.0,) * n + self.assertEqual(dist(p, q), fourthmax * math.sqrt(n)) + self.assertEqual(dist(q, p), fourthmax * math.sqrt(n)) + + # Verify scaling for extremely small values + for exp in range(32): + scale = FLOAT_MIN / 2.0 ** exp + p = (4*scale, 3*scale) + q = (0.0, 0.0) + self.assertEqual(math.dist(p, q), 5*scale) + self.assertEqual(math.dist(q, p), 5*scale) + + @unittest.skip('TODO RustPython') + def testIsqrt(self): + # Test a variety of inputs, large and small. + test_values = ( + list(range(1000)) + + list(range(10**6 - 1000, 10**6 + 1000)) + + [2**e + i for e in range(60, 200) for i in range(-40, 40)] + + [3**9999, 10**5001] + ) + + for value in test_values: + with self.subTest(value=value): + s = math.isqrt(value) + self.assertIs(type(s), int) + self.assertLessEqual(s*s, value) + self.assertLess(value, (s+1)*(s+1)) + + # Negative values + with self.assertRaises(ValueError): + math.isqrt(-1) + + # Integer-like things + s = math.isqrt(True) + self.assertIs(type(s), int) + self.assertEqual(s, 1) + + s = math.isqrt(False) + self.assertIs(type(s), int) + self.assertEqual(s, 0) + + class IntegerLike(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + + s = math.isqrt(IntegerLike(1729)) + self.assertIs(type(s), int) + self.assertEqual(s, 41) + + with self.assertRaises(ValueError): + math.isqrt(IntegerLike(-3)) + + # Non-integer-like things + bad_values = [ + 3.5, "a string", decimal.Decimal("3.5"), 3.5j, + 100.0, -4.0, + ] + for value in bad_values: + with self.subTest(value=value): + with self.assertRaises(TypeError): + math.isqrt(value) + + # Python 3.9 + def testlcm(self): + lcm = math.lcm + self.assertEqual(lcm(0, 0), 0) + self.assertEqual(lcm(1, 0), 0) + self.assertEqual(lcm(-1, 0), 0) + self.assertEqual(lcm(0, 1), 0) + self.assertEqual(lcm(0, -1), 0) + self.assertEqual(lcm(7, 1), 7) + self.assertEqual(lcm(7, -1), 7) + self.assertEqual(lcm(-23, 15), 345) + self.assertEqual(lcm(120, 84), 840) + self.assertEqual(lcm(84, -120), 840) + self.assertEqual(lcm(1216342683557601535506311712, + 436522681849110124616458784), + 16592536571065866494401400422922201534178938447014944) + + x = 43461045657039990237 + y = 10645022458251153277 + for c in (652560, + 57655923087165495981): + a = x * c + b = y * c + d = x * y * c + self.assertEqual(lcm(a, b), d) + self.assertEqual(lcm(b, a), d) + self.assertEqual(lcm(-a, b), d) + self.assertEqual(lcm(b, -a), d) + self.assertEqual(lcm(a, -b), d) + self.assertEqual(lcm(-b, a), d) + self.assertEqual(lcm(-a, -b), d) + self.assertEqual(lcm(-b, -a), d) + + self.assertEqual(lcm(), 1) + self.assertEqual(lcm(120), 120) + self.assertEqual(lcm(-120), 120) + self.assertEqual(lcm(120, 84, 102), 14280) + self.assertEqual(lcm(120, 0, 84), 0) + + self.assertRaises(TypeError, lcm, 120.0) + self.assertRaises(TypeError, lcm, 120.0, 84) + self.assertRaises(TypeError, lcm, 120, 84.0) + self.assertRaises(TypeError, lcm, 120, 0, 84.0) + # self.assertEqual(lcm(MyIndexable(120), MyIndexable(84)), 840) # TODO RustPython + + @unittest.skip('TODO RustPython') + def testLdexp(self): + self.assertRaises(TypeError, math.ldexp) + self.ftest('ldexp(0,1)', math.ldexp(0,1), 0) + self.ftest('ldexp(1,1)', math.ldexp(1,1), 2) + self.ftest('ldexp(1,-1)', math.ldexp(1,-1), 0.5) + self.ftest('ldexp(-1,1)', math.ldexp(-1,1), -2) + self.assertRaises(OverflowError, math.ldexp, 1., 1000000) + self.assertRaises(OverflowError, math.ldexp, -1., 1000000) + self.assertEqual(math.ldexp(1., -1000000), 0.) + self.assertEqual(math.ldexp(-1., -1000000), -0.) + self.assertEqual(math.ldexp(INF, 30), INF) + self.assertEqual(math.ldexp(NINF, -213), NINF) + self.assertTrue(math.isnan(math.ldexp(NAN, 0))) + + # large second argument + for n in [10**5, 10**10, 10**20, 10**40]: + self.assertEqual(math.ldexp(INF, -n), INF) + self.assertEqual(math.ldexp(NINF, -n), NINF) + self.assertEqual(math.ldexp(1., -n), 0.) + self.assertEqual(math.ldexp(-1., -n), -0.) + self.assertEqual(math.ldexp(0., -n), 0.) + self.assertEqual(math.ldexp(-0., -n), -0.) + self.assertTrue(math.isnan(math.ldexp(NAN, -n))) + + self.assertRaises(OverflowError, math.ldexp, 1., n) + self.assertRaises(OverflowError, math.ldexp, -1., n) + self.assertEqual(math.ldexp(0., n), 0.) + self.assertEqual(math.ldexp(-0., n), -0.) + self.assertEqual(math.ldexp(INF, n), INF) + self.assertEqual(math.ldexp(NINF, n), NINF) + self.assertTrue(math.isnan(math.ldexp(NAN, n))) + + @unittest.skip('TODO RustPython') + def testLog(self): + self.assertRaises(TypeError, math.log) + self.ftest('log(1/e)', math.log(1/math.e), -1) + self.ftest('log(1)', math.log(1), 0) + self.ftest('log(e)', math.log(math.e), 1) + self.ftest('log(32,2)', math.log(32,2), 5) + self.ftest('log(10**40, 10)', math.log(10**40, 10), 40) + self.ftest('log(10**40, 10**20)', math.log(10**40, 10**20), 2) + self.ftest('log(10**1000)', math.log(10**1000), + 2302.5850929940457) + self.assertRaises(ValueError, math.log, -1.5) + self.assertRaises(ValueError, math.log, -10**1000) + self.assertRaises(ValueError, math.log, NINF) + self.assertEqual(math.log(INF), INF) + self.assertTrue(math.isnan(math.log(NAN))) + + @unittest.skip('TODO RustPython') + def testLog1p(self): + self.assertRaises(TypeError, math.log1p) + for n in [2, 2**90, 2**300]: + self.assertAlmostEqual(math.log1p(n), math.log1p(float(n))) + self.assertRaises(ValueError, math.log1p, -1) + self.assertEqual(math.log1p(INF), INF) + + # TODO Rustpython + # @requires_IEEE_754 + # def testLog2(self): + # self.assertRaises(TypeError, math.log2) + + # # Check some integer values + # self.assertEqual(math.log2(1), 0.0) + # self.assertEqual(math.log2(2), 1.0) + # self.assertEqual(math.log2(4), 2.0) + + # # Large integer values + # self.assertEqual(math.log2(2**1023), 1023.0) + # self.assertEqual(math.log2(2**1024), 1024.0) + # self.assertEqual(math.log2(2**2000), 2000.0) + + # self.assertRaises(ValueError, math.log2, -1.5) + # self.assertRaises(ValueError, math.log2, NINF) + # self.assertTrue(math.isnan(math.log2(NAN))) + + # TODO Rustpython + # @requires_IEEE_754 + # # log2() is not accurate enough on Mac OS X Tiger (10.4) + # @support.requires_mac_ver(10, 5) + # def testLog2Exact(self): + # # Check that we get exact equality for log2 of powers of 2. + # actual = [math.log2(math.ldexp(1.0, n)) for n in range(-1074, 1024)] + # expected = [float(n) for n in range(-1074, 1024)] + # self.assertEqual(actual, expected) + + # def testLog10(self): + # self.assertRaises(TypeError, math.log10) + # self.ftest('log10(0.1)', math.log10(0.1), -1) + # self.ftest('log10(1)', math.log10(1), 0) + # self.ftest('log10(10)', math.log10(10), 1) + # self.ftest('log10(10**1000)', math.log10(10**1000), 1000.0) + # self.assertRaises(ValueError, math.log10, -1.5) + # self.assertRaises(ValueError, math.log10, -10**1000) + # self.assertRaises(ValueError, math.log10, NINF) + # self.assertEqual(math.log(INF), INF) + # self.assertTrue(math.isnan(math.log10(NAN))) + + def testModf(self): + self.assertRaises(TypeError, math.modf) + + def testmodf(name, result, expected): + (v1, v2), (e1, e2) = result, expected + if abs(v1-e1) > eps or abs(v2-e2): + self.fail('%s returned %r, expected %r'%\ + (name, result, expected)) + + testmodf('modf(1.5)', math.modf(1.5), (0.5, 1.0)) + testmodf('modf(-1.5)', math.modf(-1.5), (-0.5, -1.0)) + + self.assertEqual(math.modf(INF), (0.0, INF)) + self.assertEqual(math.modf(NINF), (-0.0, NINF)) + + modf_nan = math.modf(NAN) + self.assertTrue(math.isnan(modf_nan[0])) + self.assertTrue(math.isnan(modf_nan[1])) + + @unittest.skip('TODO RustPython') + def testPow(self): + self.assertRaises(TypeError, math.pow) + self.ftest('pow(0,1)', math.pow(0,1), 0) + self.ftest('pow(1,0)', math.pow(1,0), 1) + self.ftest('pow(2,1)', math.pow(2,1), 2) + self.ftest('pow(2,-1)', math.pow(2,-1), 0.5) + self.assertEqual(math.pow(INF, 1), INF) + self.assertEqual(math.pow(NINF, 1), NINF) + self.assertEqual((math.pow(1, INF)), 1.) + self.assertEqual((math.pow(1, NINF)), 1.) + self.assertTrue(math.isnan(math.pow(NAN, 1))) + self.assertTrue(math.isnan(math.pow(2, NAN))) + self.assertTrue(math.isnan(math.pow(0, NAN))) + self.assertEqual(math.pow(1, NAN), 1) + + # pow(0., x) + self.assertEqual(math.pow(0., INF), 0.) + self.assertEqual(math.pow(0., 3.), 0.) + self.assertEqual(math.pow(0., 2.3), 0.) + self.assertEqual(math.pow(0., 2.), 0.) + self.assertEqual(math.pow(0., 0.), 1.) + self.assertEqual(math.pow(0., -0.), 1.) + self.assertRaises(ValueError, math.pow, 0., -2.) + self.assertRaises(ValueError, math.pow, 0., -2.3) + self.assertRaises(ValueError, math.pow, 0., -3.) + self.assertRaises(ValueError, math.pow, 0., NINF) + self.assertTrue(math.isnan(math.pow(0., NAN))) + + # pow(INF, x) + self.assertEqual(math.pow(INF, INF), INF) + self.assertEqual(math.pow(INF, 3.), INF) + self.assertEqual(math.pow(INF, 2.3), INF) + self.assertEqual(math.pow(INF, 2.), INF) + self.assertEqual(math.pow(INF, 0.), 1.) + self.assertEqual(math.pow(INF, -0.), 1.) + self.assertEqual(math.pow(INF, -2.), 0.) + self.assertEqual(math.pow(INF, -2.3), 0.) + self.assertEqual(math.pow(INF, -3.), 0.) + self.assertEqual(math.pow(INF, NINF), 0.) + self.assertTrue(math.isnan(math.pow(INF, NAN))) + + # pow(-0., x) + self.assertEqual(math.pow(-0., INF), 0.) + self.assertEqual(math.pow(-0., 3.), -0.) + self.assertEqual(math.pow(-0., 2.3), 0.) + self.assertEqual(math.pow(-0., 2.), 0.) + self.assertEqual(math.pow(-0., 0.), 1.) + self.assertEqual(math.pow(-0., -0.), 1.) + self.assertRaises(ValueError, math.pow, -0., -2.) + self.assertRaises(ValueError, math.pow, -0., -2.3) + self.assertRaises(ValueError, math.pow, -0., -3.) + self.assertRaises(ValueError, math.pow, -0., NINF) + self.assertTrue(math.isnan(math.pow(-0., NAN))) + + # pow(NINF, x) + self.assertEqual(math.pow(NINF, INF), INF) + self.assertEqual(math.pow(NINF, 3.), NINF) + self.assertEqual(math.pow(NINF, 2.3), INF) + self.assertEqual(math.pow(NINF, 2.), INF) + self.assertEqual(math.pow(NINF, 0.), 1.) + self.assertEqual(math.pow(NINF, -0.), 1.) + self.assertEqual(math.pow(NINF, -2.), 0.) + self.assertEqual(math.pow(NINF, -2.3), 0.) + self.assertEqual(math.pow(NINF, -3.), -0.) + self.assertEqual(math.pow(NINF, NINF), 0.) + self.assertTrue(math.isnan(math.pow(NINF, NAN))) + + # pow(-1, x) + self.assertEqual(math.pow(-1., INF), 1.) + self.assertEqual(math.pow(-1., 3.), -1.) + self.assertRaises(ValueError, math.pow, -1., 2.3) + self.assertEqual(math.pow(-1., 2.), 1.) + self.assertEqual(math.pow(-1., 0.), 1.) + self.assertEqual(math.pow(-1., -0.), 1.) + self.assertEqual(math.pow(-1., -2.), 1.) + self.assertRaises(ValueError, math.pow, -1., -2.3) + self.assertEqual(math.pow(-1., -3.), -1.) + self.assertEqual(math.pow(-1., NINF), 1.) + self.assertTrue(math.isnan(math.pow(-1., NAN))) + + # pow(1, x) + self.assertEqual(math.pow(1., INF), 1.) + self.assertEqual(math.pow(1., 3.), 1.) + self.assertEqual(math.pow(1., 2.3), 1.) + self.assertEqual(math.pow(1., 2.), 1.) + self.assertEqual(math.pow(1., 0.), 1.) + self.assertEqual(math.pow(1., -0.), 1.) + self.assertEqual(math.pow(1., -2.), 1.) + self.assertEqual(math.pow(1., -2.3), 1.) + self.assertEqual(math.pow(1., -3.), 1.) + self.assertEqual(math.pow(1., NINF), 1.) + self.assertEqual(math.pow(1., NAN), 1.) + + # pow(x, 0) should be 1 for any x + self.assertEqual(math.pow(2.3, 0.), 1.) + self.assertEqual(math.pow(-2.3, 0.), 1.) + self.assertEqual(math.pow(NAN, 0.), 1.) + self.assertEqual(math.pow(2.3, -0.), 1.) + self.assertEqual(math.pow(-2.3, -0.), 1.) + self.assertEqual(math.pow(NAN, -0.), 1.) + + # pow(x, y) is invalid if x is negative and y is not integral + self.assertRaises(ValueError, math.pow, -1., 2.3) + self.assertRaises(ValueError, math.pow, -15., -3.1) + + # pow(x, NINF) + self.assertEqual(math.pow(1.9, NINF), 0.) + self.assertEqual(math.pow(1.1, NINF), 0.) + self.assertEqual(math.pow(0.9, NINF), INF) + self.assertEqual(math.pow(0.1, NINF), INF) + self.assertEqual(math.pow(-0.1, NINF), INF) + self.assertEqual(math.pow(-0.9, NINF), INF) + self.assertEqual(math.pow(-1.1, NINF), 0.) + self.assertEqual(math.pow(-1.9, NINF), 0.) + + # pow(x, INF) + self.assertEqual(math.pow(1.9, INF), INF) + self.assertEqual(math.pow(1.1, INF), INF) + self.assertEqual(math.pow(0.9, INF), 0.) + self.assertEqual(math.pow(0.1, INF), 0.) + self.assertEqual(math.pow(-0.1, INF), 0.) + self.assertEqual(math.pow(-0.9, INF), 0.) + self.assertEqual(math.pow(-1.1, INF), INF) + self.assertEqual(math.pow(-1.9, INF), INF) + + # pow(x, y) should work for x negative, y an integer + self.ftest('(-2.)**3.', math.pow(-2.0, 3.0), -8.0) + self.ftest('(-2.)**2.', math.pow(-2.0, 2.0), 4.0) + self.ftest('(-2.)**1.', math.pow(-2.0, 1.0), -2.0) + self.ftest('(-2.)**0.', math.pow(-2.0, 0.0), 1.0) + self.ftest('(-2.)**-0.', math.pow(-2.0, -0.0), 1.0) + self.ftest('(-2.)**-1.', math.pow(-2.0, -1.0), -0.5) + self.ftest('(-2.)**-2.', math.pow(-2.0, -2.0), 0.25) + self.ftest('(-2.)**-3.', math.pow(-2.0, -3.0), -0.125) + self.assertRaises(ValueError, math.pow, -2.0, -0.5) + self.assertRaises(ValueError, math.pow, -2.0, 0.5) + + # the following tests have been commented out since they don't + # really belong here: the implementation of ** for floats is + # independent of the implementation of math.pow + #self.assertEqual(1**NAN, 1) + #self.assertEqual(1**INF, 1) + #self.assertEqual(1**NINF, 1) + #self.assertEqual(1**0, 1) + #self.assertEqual(1.**NAN, 1) + #self.assertEqual(1.**INF, 1) + #self.assertEqual(1.**NINF, 1) + #self.assertEqual(1.**0, 1) + + def testRadians(self): + self.assertRaises(TypeError, math.radians) + self.ftest('radians(180)', math.radians(180), math.pi) + self.ftest('radians(90)', math.radians(90), math.pi/2) + self.ftest('radians(-45)', math.radians(-45), -math.pi/4) + self.ftest('radians(0)', math.radians(0), 0) + + # TODO Rustpython + # @requires_IEEE_754 + # def testRemainder(self): + # from fractions import Fraction + + # def validate_spec(x, y, r): + # """ + # Check that r matches remainder(x, y) according to the IEEE 754 + # specification. Assumes that x, y and r are finite and y is nonzero. + # """ + # fx, fy, fr = Fraction(x), Fraction(y), Fraction(r) + # # r should not exceed y/2 in absolute value + # self.assertLessEqual(abs(fr), abs(fy/2)) + # # x - r should be an exact integer multiple of y + # n = (fx - fr) / fy + # self.assertEqual(n, int(n)) + # if abs(fr) == abs(fy/2): + # # If |r| == |y/2|, n should be even. + # self.assertEqual(n/2, int(n/2)) + + # # triples (x, y, remainder(x, y)) in hexadecimal form. + # testcases = [ + # # Remainders modulo 1, showing the ties-to-even behaviour. + # '-4.0 1 -0.0', + # '-3.8 1 0.8', + # '-3.0 1 -0.0', + # '-2.8 1 -0.8', + # '-2.0 1 -0.0', + # '-1.8 1 0.8', + # '-1.0 1 -0.0', + # '-0.8 1 -0.8', + # '-0.0 1 -0.0', + # ' 0.0 1 0.0', + # ' 0.8 1 0.8', + # ' 1.0 1 0.0', + # ' 1.8 1 -0.8', + # ' 2.0 1 0.0', + # ' 2.8 1 0.8', + # ' 3.0 1 0.0', + # ' 3.8 1 -0.8', + # ' 4.0 1 0.0', + + # # Reductions modulo 2*pi + # '0x0.0p+0 0x1.921fb54442d18p+2 0x0.0p+0', + # '0x1.921fb54442d18p+0 0x1.921fb54442d18p+2 0x1.921fb54442d18p+0', + # '0x1.921fb54442d17p+1 0x1.921fb54442d18p+2 0x1.921fb54442d17p+1', + # '0x1.921fb54442d18p+1 0x1.921fb54442d18p+2 0x1.921fb54442d18p+1', + # '0x1.921fb54442d19p+1 0x1.921fb54442d18p+2 -0x1.921fb54442d17p+1', + # '0x1.921fb54442d17p+2 0x1.921fb54442d18p+2 -0x0.0000000000001p+2', + # '0x1.921fb54442d18p+2 0x1.921fb54442d18p+2 0x0p0', + # '0x1.921fb54442d19p+2 0x1.921fb54442d18p+2 0x0.0000000000001p+2', + # '0x1.2d97c7f3321d1p+3 0x1.921fb54442d18p+2 0x1.921fb54442d14p+1', + # '0x1.2d97c7f3321d2p+3 0x1.921fb54442d18p+2 -0x1.921fb54442d18p+1', + # '0x1.2d97c7f3321d3p+3 0x1.921fb54442d18p+2 -0x1.921fb54442d14p+1', + # '0x1.921fb54442d17p+3 0x1.921fb54442d18p+2 -0x0.0000000000001p+3', + # '0x1.921fb54442d18p+3 0x1.921fb54442d18p+2 0x0p0', + # '0x1.921fb54442d19p+3 0x1.921fb54442d18p+2 0x0.0000000000001p+3', + # '0x1.f6a7a2955385dp+3 0x1.921fb54442d18p+2 0x1.921fb54442d14p+1', + # '0x1.f6a7a2955385ep+3 0x1.921fb54442d18p+2 0x1.921fb54442d18p+1', + # '0x1.f6a7a2955385fp+3 0x1.921fb54442d18p+2 -0x1.921fb54442d14p+1', + # '0x1.1475cc9eedf00p+5 0x1.921fb54442d18p+2 0x1.921fb54442d10p+1', + # '0x1.1475cc9eedf01p+5 0x1.921fb54442d18p+2 -0x1.921fb54442d10p+1', + + # # Symmetry with respect to signs. + # ' 1 0.c 0.4', + # '-1 0.c -0.4', + # ' 1 -0.c 0.4', + # '-1 -0.c -0.4', + # ' 1.4 0.c -0.4', + # '-1.4 0.c 0.4', + # ' 1.4 -0.c -0.4', + # '-1.4 -0.c 0.4', + + # # Huge modulus, to check that the underlying algorithm doesn't + # # rely on 2.0 * modulus being representable. + # '0x1.dp+1023 0x1.4p+1023 0x0.9p+1023', + # '0x1.ep+1023 0x1.4p+1023 -0x0.ap+1023', + # '0x1.fp+1023 0x1.4p+1023 -0x0.9p+1023', + # ] + + # for case in testcases: + # with self.subTest(case=case): + # x_hex, y_hex, expected_hex = case.split() + # x = float.fromhex(x_hex) + # y = float.fromhex(y_hex) + # expected = float.fromhex(expected_hex) + # validate_spec(x, y, expected) + # actual = math.remainder(x, y) + # # Cheap way of checking that the floats are + # # as identical as we need them to be. + # self.assertEqual(actual.hex(), expected.hex()) + + # # Test tiny subnormal modulus: there's potential for + # # getting the implementation wrong here (for example, + # # by assuming that modulus/2 is exactly representable). + # tiny = float.fromhex('1p-1074') # min +ve subnormal + # for n in range(-25, 25): + # if n == 0: + # continue + # y = n * tiny + # for m in range(100): + # x = m * tiny + # actual = math.remainder(x, y) + # validate_spec(x, y, actual) + # actual = math.remainder(-x, y) + # validate_spec(-x, y, actual) + + # # Special values. + # # NaNs should propagate as usual. + # for value in [NAN, 0.0, -0.0, 2.0, -2.3, NINF, INF]: + # self.assertIsNaN(math.remainder(NAN, value)) + # self.assertIsNaN(math.remainder(value, NAN)) + + # # remainder(x, inf) is x, for non-nan non-infinite x. + # for value in [-2.3, -0.0, 0.0, 2.3]: + # self.assertEqual(math.remainder(value, INF), value) + # self.assertEqual(math.remainder(value, NINF), value) + + # # remainder(x, 0) and remainder(infinity, x) for non-NaN x are invalid + # # operations according to IEEE 754-2008 7.2(f), and should raise. + # for value in [NINF, -2.3, -0.0, 0.0, 2.3, INF]: + # with self.assertRaises(ValueError): + # math.remainder(INF, value) + # with self.assertRaises(ValueError): + # math.remainder(NINF, value) + # with self.assertRaises(ValueError): + # math.remainder(value, 0.0) + # with self.assertRaises(ValueError): + # math.remainder(value, -0.0) + + def testSin(self): + self.assertRaises(TypeError, math.sin) + self.ftest('sin(0)', math.sin(0), 0) + self.ftest('sin(pi/2)', math.sin(math.pi/2), 1) + self.ftest('sin(-pi/2)', math.sin(-math.pi/2), -1) + try: + self.assertTrue(math.isnan(math.sin(INF))) + self.assertTrue(math.isnan(math.sin(NINF))) + except ValueError: + self.assertRaises(ValueError, math.sin, INF) + self.assertRaises(ValueError, math.sin, NINF) + self.assertTrue(math.isnan(math.sin(NAN))) + + def testSinh(self): + self.assertRaises(TypeError, math.sinh) + self.ftest('sinh(0)', math.sinh(0), 0) + self.ftest('sinh(1)**2-cosh(1)**2', math.sinh(1)**2-math.cosh(1)**2, -1) + self.ftest('sinh(1)+sinh(-1)', math.sinh(1)+math.sinh(-1), 0) + self.assertEqual(math.sinh(INF), INF) + self.assertEqual(math.sinh(NINF), NINF) + self.assertTrue(math.isnan(math.sinh(NAN))) + + @unittest.skip('TODO RustPython') + def testSqrt(self): + self.assertRaises(TypeError, math.sqrt) + self.ftest('sqrt(0)', math.sqrt(0), 0) + self.ftest('sqrt(1)', math.sqrt(1), 1) + self.ftest('sqrt(4)', math.sqrt(4), 2) + self.assertEqual(math.sqrt(INF), INF) + self.assertRaises(ValueError, math.sqrt, -1) + self.assertRaises(ValueError, math.sqrt, NINF) + self.assertTrue(math.isnan(math.sqrt(NAN))) + + def testTan(self): + self.assertRaises(TypeError, math.tan) + self.ftest('tan(0)', math.tan(0), 0) + self.ftest('tan(pi/4)', math.tan(math.pi/4), 1) + self.ftest('tan(-pi/4)', math.tan(-math.pi/4), -1) + try: + self.assertTrue(math.isnan(math.tan(INF))) + self.assertTrue(math.isnan(math.tan(NINF))) + except: + self.assertRaises(ValueError, math.tan, INF) + self.assertRaises(ValueError, math.tan, NINF) + self.assertTrue(math.isnan(math.tan(NAN))) + + @unittest.skip('TODO RustPython') + def testTanh(self): + self.assertRaises(TypeError, math.tanh) + self.ftest('tanh(0)', math.tanh(0), 0) + self.ftest('tanh(1)+tanh(-1)', math.tanh(1)+math.tanh(-1), 0, + abs_tol=math.ulp(1)) + self.ftest('tanh(inf)', math.tanh(INF), 1) + self.ftest('tanh(-inf)', math.tanh(NINF), -1) + self.assertTrue(math.isnan(math.tanh(NAN))) + + # TODO Rustpython + # @requires_IEEE_754 + # def testTanhSign(self): + # # check that tanh(-0.) == -0. on IEEE 754 systems + # self.assertEqual(math.tanh(-0.), -0.) + # self.assertEqual(math.copysign(1., math.tanh(-0.)), + # math.copysign(1., -0.)) + + def test_trunc(self): + self.assertEqual(math.trunc(1), 1) + self.assertEqual(math.trunc(-1), -1) + self.assertEqual(type(math.trunc(1)), int) + self.assertEqual(type(math.trunc(1.5)), int) + self.assertEqual(math.trunc(1.5), 1) + self.assertEqual(math.trunc(-1.5), -1) + self.assertEqual(math.trunc(1.999999), 1) + self.assertEqual(math.trunc(-1.999999), -1) + self.assertEqual(math.trunc(-0.999999), -0) + self.assertEqual(math.trunc(-100.999), -100) + + class TestTrunc: + def __trunc__(self): + return 23 + class FloatTrunc(float): + def __trunc__(self): + return 23 + class TestNoTrunc: + pass + + self.assertEqual(math.trunc(TestTrunc()), 23) + self.assertEqual(math.trunc(FloatTrunc()), 23) + + self.assertRaises(TypeError, math.trunc) + self.assertRaises(TypeError, math.trunc, 1, 2) + self.assertRaises(TypeError, math.trunc, FloatLike(23.5)) + self.assertRaises(TypeError, math.trunc, TestNoTrunc()) + + def testIsfinite(self): + self.assertTrue(math.isfinite(0.0)) + self.assertTrue(math.isfinite(-0.0)) + self.assertTrue(math.isfinite(1.0)) + self.assertTrue(math.isfinite(-1.0)) + self.assertFalse(math.isfinite(float("nan"))) + self.assertFalse(math.isfinite(float("inf"))) + self.assertFalse(math.isfinite(float("-inf"))) + + def testIsnan(self): + self.assertTrue(math.isnan(float("nan"))) + self.assertTrue(math.isnan(float("-nan"))) + self.assertTrue(math.isnan(float("inf") * 0.)) + self.assertFalse(math.isnan(float("inf"))) + self.assertFalse(math.isnan(0.)) + self.assertFalse(math.isnan(1.)) + + def testIsinf(self): + self.assertTrue(math.isinf(float("inf"))) + self.assertTrue(math.isinf(float("-inf"))) + self.assertTrue(math.isinf(1E400)) + self.assertTrue(math.isinf(-1E400)) + self.assertFalse(math.isinf(float("nan"))) + self.assertFalse(math.isinf(0.)) + self.assertFalse(math.isinf(1.)) + + # TODO Rustpython + # @requires_IEEE_754 + # def test_nan_constant(self): + # self.assertTrue(math.isnan(math.nan)) + + # TODO Rustpython + # @requires_IEEE_754 + # def test_inf_constant(self): + # self.assertTrue(math.isinf(math.inf)) + # self.assertGreater(math.inf, 0.0) + # self.assertEqual(math.inf, float("inf")) + # self.assertEqual(-math.inf, float("-inf")) + + # RED_FLAG 16-Oct-2000 Tim + # While 2.0 is more consistent about exceptions than previous releases, it + # still fails this part of the test on some platforms. For now, we only + # *run* test_exceptions() in verbose mode, so that this isn't normally + # tested. + @unittest.skip('TODO RustPython') + @unittest.skipUnless(verbose, 'requires verbose mode') + def test_exceptions(self): + try: + x = math.exp(-1000000000) + except: + # mathmodule.c is failing to weed out underflows from libm, or + # we've got an fp format with huge dynamic range + self.fail("underflowing exp() should not have raised " + "an exception") + if x != 0: + self.fail("underflowing exp() should have returned 0") + + # If this fails, probably using a strict IEEE-754 conforming libm, and x + # is +Inf afterwards. But Python wants overflows detected by default. + try: + x = math.exp(1000000000) + except OverflowError: + pass + else: + self.fail("overflowing exp() didn't trigger OverflowError") + + # If this fails, it could be a puzzle. One odd possibility is that + # mathmodule.c's macros are getting confused while comparing + # Inf (HUGE_VAL) to a NaN, and artificially setting errno to ERANGE + # as a result (and so raising OverflowError instead). + try: + x = math.sqrt(-1.0) + except ValueError: + pass + else: + self.fail("sqrt(-1) didn't raise ValueError") + + # TODO Rustpython + # @requires_IEEE_754 + # def test_testfile(self): + # # Some tests need to be skipped on ancient OS X versions. + # # See issue #27953. + # SKIP_ON_TIGER = {'tan0064'} + + # osx_version = None + # if sys.platform == 'darwin': + # version_txt = platform.mac_ver()[0] + # try: + # osx_version = tuple(map(int, version_txt.split('.'))) + # except ValueError: + # pass + + # fail_fmt = "{}: {}({!r}): {}" + + # failures = [] + # for id, fn, ar, ai, er, ei, flags in parse_testfile(test_file): + # # Skip if either the input or result is complex + # if ai != 0.0 or ei != 0.0: + # continue + # if fn in ['rect', 'polar']: + # # no real versions of rect, polar + # continue + # # Skip certain tests on OS X 10.4. + # if osx_version is not None and osx_version < (10, 5): + # if id in SKIP_ON_TIGER: + # continue + + # func = getattr(math, fn) + + # if 'invalid' in flags or 'divide-by-zero' in flags: + # er = 'ValueError' + # elif 'overflow' in flags: + # er = 'OverflowError' + + # try: + # result = func(ar) + # except ValueError: + # result = 'ValueError' + # except OverflowError: + # result = 'OverflowError' + + # # Default tolerances + # ulp_tol, abs_tol = 5, 0.0 + + # failure = result_check(er, result, ulp_tol, abs_tol) + # if failure is None: + # continue + + # msg = fail_fmt.format(id, fn, ar, failure) + # failures.append(msg) + + # if failures: + # self.fail('Failures in test_testfile:\n ' + + # '\n '.join(failures)) + + # TODO Rustpython + # @requires_IEEE_754 + # def test_mtestfile(self): + # fail_fmt = "{}: {}({!r}): {}" + + # failures = [] + # for id, fn, arg, expected, flags in parse_mtestfile(math_testcases): + # func = getattr(math, fn) + + # if 'invalid' in flags or 'divide-by-zero' in flags: + # expected = 'ValueError' + # elif 'overflow' in flags: + # expected = 'OverflowError' + + # try: + # got = func(arg) + # except ValueError: + # got = 'ValueError' + # except OverflowError: + # got = 'OverflowError' + + # # Default tolerances + # ulp_tol, abs_tol = 5, 0.0 + + # # Exceptions to the defaults + # if fn == 'gamma': + # # Experimental results on one platform gave + # # an accuracy of <= 10 ulps across the entire float + # # domain. We weaken that to require 20 ulp accuracy. + # ulp_tol = 20 + + # elif fn == 'lgamma': + # # we use a weaker accuracy test for lgamma; + # # lgamma only achieves an absolute error of + # # a few multiples of the machine accuracy, in + # # general. + # abs_tol = 1e-15 + + # elif fn == 'erfc' and arg >= 0.0: + # # erfc has less-than-ideal accuracy for large + # # arguments (x ~ 25 or so), mainly due to the + # # error involved in computing exp(-x*x). + # # + # # Observed between CPython and mpmath at 25 dp: + # # x < 0 : err <= 2 ulp + # # 0 <= x < 1 : err <= 10 ulp + # # 1 <= x < 10 : err <= 100 ulp + # # 10 <= x < 20 : err <= 300 ulp + # # 20 <= x : < 600 ulp + # # + # if arg < 1.0: + # ulp_tol = 10 + # elif arg < 10.0: + # ulp_tol = 100 + # else: + # ulp_tol = 1000 + + # failure = result_check(expected, got, ulp_tol, abs_tol) + # if failure is None: + # continue + + # msg = fail_fmt.format(id, fn, arg, failure) + # failures.append(msg) + + # if failures: + # self.fail('Failures in test_mtestfile:\n ' + + # '\n '.join(failures)) + + @unittest.skip('TODO RustPython') + def test_prod(self): + prod = math.prod + self.assertEqual(prod([]), 1) + self.assertEqual(prod([], start=5), 5) + self.assertEqual(prod(list(range(2,8))), 5040) + self.assertEqual(prod(iter(list(range(2,8)))), 5040) + self.assertEqual(prod(range(1, 10), start=10), 3628800) + + self.assertEqual(prod([1, 2, 3, 4, 5]), 120) + self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) + + # Test overflow in fast-path for integers + self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) + # Test overflow in fast-path for floats + self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32)) + + self.assertRaises(TypeError, prod) + self.assertRaises(TypeError, prod, 42) + self.assertRaises(TypeError, prod, ['a', 'b', 'c']) + self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '') + self.assertRaises(TypeError, prod, [b'a', b'c'], b'') + values = [bytearray(b'a'), bytearray(b'b')] + self.assertRaises(TypeError, prod, values, bytearray(b'')) + self.assertRaises(TypeError, prod, [[1], [2], [3]]) + self.assertRaises(TypeError, prod, [{2:3}]) + self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3}) + self.assertRaises(TypeError, prod, [[1], [2], [3]], []) + with self.assertRaises(TypeError): + prod([10, 20], [30, 40]) # start is a keyword-only argument + + self.assertEqual(prod([0, 1, 2, 3]), 0) + self.assertEqual(prod([1, 0, 2, 3]), 0) + self.assertEqual(prod([1, 2, 3, 0]), 0) + + def _naive_prod(iterable, start=1): + for elem in iterable: + start *= elem + return start + + # Big integers + + iterable = range(1, 10000) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-10000, -1) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-1000, 1000) + self.assertEqual(prod(iterable), 0) + + # Big floats + + iterable = [float(x) for x in range(1, 1000)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, -1)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, 1000)] + self.assertIsNaN(prod(iterable)) + + # Float tests + + self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, 0, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, float("nan"), 0, 3])) + self.assertIsNaN(prod([1, float("inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("-inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("nan"), float("inf"),3])) + self.assertIsNaN(prod([1, float("nan"), float("-inf"),3])) + + self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf')) + self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf')) + + self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4])) + self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4])) + self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3])) + self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2])) + + # Type preservation + + self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int) + self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float) + self.assertEqual(type(prod(range(1, 10000))), int) + self.assertEqual(type(prod(range(1, 10000), start=1.0)), float) + self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), + decimal.Decimal) + + @unittest.skip('TODO RustPython') + def testPerm(self): + perm = math.perm + factorial = math.factorial + # Test if factorial definition is satisfied + for n in range(100): + for k in range(n + 1): + self.assertEqual(perm(n, k), + factorial(n) // factorial(n - k)) + + # Test for Pascal's identity + for n in range(1, 100): + for k in range(1, n): + self.assertEqual(perm(n, k), perm(n - 1, k - 1) * k + perm(n - 1, k)) + + # Test corner cases + for n in range(1, 100): + self.assertEqual(perm(n, 0), 1) + self.assertEqual(perm(n, 1), n) + self.assertEqual(perm(n, n), factorial(n)) + + # Test one argument form + for n in range(20): + self.assertEqual(perm(n), factorial(n)) + self.assertEqual(perm(n, None), factorial(n)) + + # Raises TypeError if any argument is non-integer or argument count is + # not 1 or 2 + self.assertRaises(TypeError, perm, 10, 1.0) + self.assertRaises(TypeError, perm, 10, decimal.Decimal(1.0)) + self.assertRaises(TypeError, perm, 10, "1") + self.assertRaises(TypeError, perm, 10.0, 1) + self.assertRaises(TypeError, perm, decimal.Decimal(10.0), 1) + self.assertRaises(TypeError, perm, "10", 1) + + self.assertRaises(TypeError, perm) + self.assertRaises(TypeError, perm, 10, 1, 3) + self.assertRaises(TypeError, perm) + + # Raises Value error if not k or n are negative numbers + self.assertRaises(ValueError, perm, -1, 1) + self.assertRaises(ValueError, perm, -2**1000, 1) + self.assertRaises(ValueError, perm, 1, -1) + self.assertRaises(ValueError, perm, 1, -2**1000) + + # Returns zero if k is greater than n + self.assertEqual(perm(1, 2), 0) + self.assertEqual(perm(1, 2**1000), 0) + + n = 2**1000 + self.assertEqual(perm(n, 0), 1) + self.assertEqual(perm(n, 1), n) + self.assertEqual(perm(n, 2), n * (n-1)) + if support.check_impl_detail(cpython=True): + self.assertRaises(OverflowError, perm, n, n) + + for n, k in (True, True), (True, False), (False, False): + self.assertEqual(perm(n, k), 1) + self.assertIs(type(perm(n, k)), int) + self.assertEqual(perm(IntSubclass(5), IntSubclass(2)), 20) + self.assertEqual(perm(MyIndexable(5), MyIndexable(2)), 20) + for k in range(3): + self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) + + @unittest.skip('TODO RustPython') + def testComb(self): + comb = math.comb + factorial = math.factorial + # Test if factorial definition is satisfied + for n in range(100): + for k in range(n + 1): + self.assertEqual(comb(n, k), factorial(n) + // (factorial(k) * factorial(n - k))) + + # Test for Pascal's identity + for n in range(1, 100): + for k in range(1, n): + self.assertEqual(comb(n, k), comb(n - 1, k - 1) + comb(n - 1, k)) + + # Test corner cases + for n in range(100): + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, n), 1) + + for n in range(1, 100): + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, n - 1), n) + + # Test Symmetry + for n in range(100): + for k in range(n // 2): + self.assertEqual(comb(n, k), comb(n, n - k)) + + # Raises TypeError if any argument is non-integer or argument count is + # not 2 + self.assertRaises(TypeError, comb, 10, 1.0) + self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0)) + self.assertRaises(TypeError, comb, 10, "1") + self.assertRaises(TypeError, comb, 10.0, 1) + self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1) + self.assertRaises(TypeError, comb, "10", 1) + + self.assertRaises(TypeError, comb, 10) + self.assertRaises(TypeError, comb, 10, 1, 3) + self.assertRaises(TypeError, comb) + + # Raises Value error if not k or n are negative numbers + self.assertRaises(ValueError, comb, -1, 1) + self.assertRaises(ValueError, comb, -2**1000, 1) + self.assertRaises(ValueError, comb, 1, -1) + self.assertRaises(ValueError, comb, 1, -2**1000) + + # Returns zero if k is greater than n + self.assertEqual(comb(1, 2), 0) + self.assertEqual(comb(1, 2**1000), 0) + + n = 2**1000 + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, 2), n * (n-1) // 2) + self.assertEqual(comb(n, n), 1) + self.assertEqual(comb(n, n-1), n) + self.assertEqual(comb(n, n-2), n * (n-1) // 2) + if support.check_impl_detail(cpython=True): + self.assertRaises(OverflowError, comb, n, n//2) + + for n, k in (True, True), (True, False), (False, False): + self.assertEqual(comb(n, k), 1) + self.assertIs(type(comb(n, k)), int) + self.assertEqual(comb(IntSubclass(5), IntSubclass(2)), 10) + self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10) + for k in range(3): + self.assertIs(type(comb(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(comb(MyIndexable(5), MyIndexable(k))), int) + + # TODO Rustpython + # @requires_IEEE_754 + # def test_nextafter(self): + # # around 2^52 and 2^63 + # self.assertEqual(math.nextafter(4503599627370496.0, -INF), + # 4503599627370495.5) + # self.assertEqual(math.nextafter(4503599627370496.0, INF), + # 4503599627370497.0) + # self.assertEqual(math.nextafter(9223372036854775808.0, 0.0), + # 9223372036854774784.0) + # self.assertEqual(math.nextafter(-9223372036854775808.0, 0.0), + # -9223372036854774784.0) + + # # around 1.0 + # self.assertEqual(math.nextafter(1.0, -INF), + # float.fromhex('0x1.fffffffffffffp-1')) + # self.assertEqual(math.nextafter(1.0, INF), + # float.fromhex('0x1.0000000000001p+0')) + + # # x == y: y is returned + # self.assertEqual(math.nextafter(2.0, 2.0), 2.0) + # self.assertEqualSign(math.nextafter(-0.0, +0.0), +0.0) + # self.assertEqualSign(math.nextafter(+0.0, -0.0), -0.0) + + # # around 0.0 + # smallest_subnormal = sys.float_info.min * sys.float_info.epsilon + # self.assertEqual(math.nextafter(+0.0, INF), smallest_subnormal) + # self.assertEqual(math.nextafter(-0.0, INF), smallest_subnormal) + # self.assertEqual(math.nextafter(+0.0, -INF), -smallest_subnormal) + # self.assertEqual(math.nextafter(-0.0, -INF), -smallest_subnormal) + # self.assertEqualSign(math.nextafter(smallest_subnormal, +0.0), +0.0) + # self.assertEqualSign(math.nextafter(-smallest_subnormal, +0.0), -0.0) + # self.assertEqualSign(math.nextafter(smallest_subnormal, -0.0), +0.0) + # self.assertEqualSign(math.nextafter(-smallest_subnormal, -0.0), -0.0) + + # # around infinity + # largest_normal = sys.float_info.max + # self.assertEqual(math.nextafter(INF, 0.0), largest_normal) + # self.assertEqual(math.nextafter(-INF, 0.0), -largest_normal) + # self.assertEqual(math.nextafter(largest_normal, INF), INF) + # self.assertEqual(math.nextafter(-largest_normal, -INF), -INF) + + # # NaN + # self.assertIsNaN(math.nextafter(NAN, 1.0)) + # self.assertIsNaN(math.nextafter(1.0, NAN)) + # self.assertIsNaN(math.nextafter(NAN, NAN)) + + # @requires_IEEE_754 + # def test_ulp(self): + # self.assertEqual(math.ulp(1.0), sys.float_info.epsilon) + # # use int ** int rather than float ** int to not rely on pow() accuracy + # self.assertEqual(math.ulp(2 ** 52), 1.0) + # self.assertEqual(math.ulp(2 ** 53), 2.0) + # self.assertEqual(math.ulp(2 ** 64), 4096.0) + + # # min and max + # self.assertEqual(math.ulp(0.0), + # sys.float_info.min * sys.float_info.epsilon) + # self.assertEqual(math.ulp(FLOAT_MAX), + # FLOAT_MAX - math.nextafter(FLOAT_MAX, -INF)) + + # # special cases + # self.assertEqual(math.ulp(INF), INF) + # self.assertIsNaN(math.ulp(math.nan)) + + # # negative number: ulp(-x) == ulp(x) + # for x in (0.0, 1.0, 2 ** 52, 2 ** 64, INF): + # with self.subTest(x=x): + # self.assertEqual(math.ulp(-x), math.ulp(x)) + + @unittest.skip('TODO RustPython') + def test_issue39871(self): + # A SystemError should not be raised if the first arg to atan2(), + # copysign(), or remainder() cannot be converted to a float. + class F: + def __float__(self): + self.converted = True + 1/0 + for func in math.atan2, math.copysign, math.remainder: + y = F() + with self.assertRaises(TypeError): + func("not a number", y) + + # There should not have been any attempt to convert the second + # argument to a float. + self.assertFalse(getattr(y, "converted", False)) + + # Custom assertions. + + def assertIsNaN(self, value): + if not math.isnan(value): + self.fail("Expected a NaN, got {!r}.".format(value)) + + def assertEqualSign(self, x, y): + """Similar to assertEqual(), but compare also the sign with copysign(). + Function useful to compare signed zeros. + """ + self.assertEqual(x, y) + self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y)) + + +class IsCloseTests(unittest.TestCase): + isclose = math.isclose # subclasses should override this + + def assertIsClose(self, a, b, *args, **kwargs): + self.assertTrue(self.isclose(a, b, *args, **kwargs), + msg="%s and %s should be close!" % (a, b)) + + def assertIsNotClose(self, a, b, *args, **kwargs): + self.assertFalse(self.isclose(a, b, *args, **kwargs), + msg="%s and %s should not be close!" % (a, b)) + + def assertAllClose(self, examples, *args, **kwargs): + for a, b in examples: + self.assertIsClose(a, b, *args, **kwargs) + + def assertAllNotClose(self, examples, *args, **kwargs): + for a, b in examples: + self.assertIsNotClose(a, b, *args, **kwargs) + + def test_negative_tolerances(self): + # ValueError should be raised if either tolerance is less than zero + with self.assertRaises(ValueError): + self.assertIsClose(1, 1, rel_tol=-1e-100) + with self.assertRaises(ValueError): + self.assertIsClose(1, 1, rel_tol=1e-100, abs_tol=-1e10) + + def test_identical(self): + # identical values must test as close + identical_examples = [(2.0, 2.0), + (0.1e200, 0.1e200), + (1.123e-300, 1.123e-300), + (12345, 12345.0), + (0.0, -0.0), + (345678, 345678)] + self.assertAllClose(identical_examples, rel_tol=0.0, abs_tol=0.0) + + def test_eight_decimal_places(self): + # examples that are close to 1e-8, but not 1e-9 + eight_decimal_places_examples = [(1e8, 1e8 + 1), + (-1e-8, -1.000000009e-8), + (1.12345678, 1.12345679)] + self.assertAllClose(eight_decimal_places_examples, rel_tol=1e-8) + self.assertAllNotClose(eight_decimal_places_examples, rel_tol=1e-9) + + def test_near_zero(self): + # values close to zero + near_zero_examples = [(1e-9, 0.0), + (-1e-9, 0.0), + (-1e-150, 0.0)] + # these should not be close to any rel_tol + self.assertAllNotClose(near_zero_examples, rel_tol=0.9) + # these should be close to abs_tol=1e-8 + self.assertAllClose(near_zero_examples, abs_tol=1e-8) + + def test_identical_infinite(self): + # these are close regardless of tolerance -- i.e. they are equal + self.assertIsClose(INF, INF) + self.assertIsClose(INF, INF, abs_tol=0.0) + self.assertIsClose(NINF, NINF) + self.assertIsClose(NINF, NINF, abs_tol=0.0) + + def test_inf_ninf_nan(self): + # these should never be close (following IEEE 754 rules for equality) + not_close_examples = [(NAN, NAN), + (NAN, 1e-100), + (1e-100, NAN), + (INF, NAN), + (NAN, INF), + (INF, NINF), + (INF, 1.0), + (1.0, INF), + (INF, 1e308), + (1e308, INF)] + # use largest reasonable tolerance + self.assertAllNotClose(not_close_examples, abs_tol=0.999999999999999) + + def test_zero_tolerance(self): + # test with zero tolerance + zero_tolerance_close_examples = [(1.0, 1.0), + (-3.4, -3.4), + (-1e-300, -1e-300)] + self.assertAllClose(zero_tolerance_close_examples, rel_tol=0.0) + + zero_tolerance_not_close_examples = [(1.0, 1.000000000000001), + (0.99999999999999, 1.0), + (1.0e200, .999999999999999e200)] + self.assertAllNotClose(zero_tolerance_not_close_examples, rel_tol=0.0) + + def test_asymmetry(self): + # test the asymmetry example from PEP 485 + self.assertAllClose([(9, 10), (10, 9)], rel_tol=0.1) + + def test_integers(self): + # test with integer values + integer_examples = [(100000001, 100000000), + (123456789, 123456788)] + + self.assertAllClose(integer_examples, rel_tol=1e-8) + self.assertAllNotClose(integer_examples, rel_tol=1e-9) + + @unittest.skip('TODO RustPython') + def test_decimals(self): + # test with Decimal values + from decimal import Decimal + + decimal_examples = [(Decimal('1.00000001'), Decimal('1.0')), + (Decimal('1.00000001e-20'), Decimal('1.0e-20')), + (Decimal('1.00000001e-100'), Decimal('1.0e-100')), + (Decimal('1.00000001e20'), Decimal('1.0e20'))] + self.assertAllClose(decimal_examples, rel_tol=1e-8) + self.assertAllNotClose(decimal_examples, rel_tol=1e-9) + + @unittest.skip('TODO Rustpython') + def test_fractions(self): + # test with Fraction values + from fractions import Fraction + + fraction_examples = [ + (Fraction(1, 100000000) + 1, Fraction(1)), + (Fraction(100000001), Fraction(100000000)), + (Fraction(10**8 + 1, 10**28), Fraction(1, 10**20))] + self.assertAllClose(fraction_examples, rel_tol=1e-8) + self.assertAllNotClose(fraction_examples, rel_tol=1e-9) + + +def test_main(): + # from doctest import DocFileSuite + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(MathTests)) + suite.addTest(unittest.makeSuite(IsCloseTests)) + # suite.addTest(DocFileSuite("ieee754.txt")) + run_unittest(suite) + +if __name__ == '__main__': + test_main() \ No newline at end of file diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py new file mode 100644 index 0000000000..8e9c5d1ee3 --- /dev/null +++ b/Lib/test/test_thread.py @@ -0,0 +1,268 @@ +import os +import unittest +import random +from test import support +import _thread as thread +import time +import weakref + +from test import lock_tests + +NUMTASKS = 10 +NUMTRIPS = 3 +POLL_SLEEP = 0.010 # seconds = 10 ms + +_print_mutex = thread.allocate_lock() + +def verbose_print(arg): + """Helper function for printing out debugging output.""" + if support.verbose: + with _print_mutex: + print(arg) + + +class BasicThreadTest(unittest.TestCase): + + def setUp(self): + self.done_mutex = thread.allocate_lock() + self.done_mutex.acquire() + self.running_mutex = thread.allocate_lock() + self.random_mutex = thread.allocate_lock() + self.created = 0 + self.running = 0 + self.next_ident = 0 + + key = support.threading_setup() + self.addCleanup(support.threading_cleanup, *key) + + +class ThreadRunningTests(BasicThreadTest): + + def newtask(self): + with self.running_mutex: + self.next_ident += 1 + verbose_print("creating task %s" % self.next_ident) + thread.start_new_thread(self.task, (self.next_ident,)) + self.created += 1 + self.running += 1 + + def task(self, ident): + with self.random_mutex: + delay = random.random() / 10000.0 + verbose_print("task %s will run for %sus" % (ident, round(delay*1e6))) + time.sleep(delay) + verbose_print("task %s done" % ident) + with self.running_mutex: + self.running -= 1 + if self.created == NUMTASKS and self.running == 0: + self.done_mutex.release() + + def test_starting_threads(self): + with support.wait_threads_exit(): + # Basic test for thread creation. + for i in range(NUMTASKS): + self.newtask() + verbose_print("waiting for tasks to complete...") + self.done_mutex.acquire() + verbose_print("all tasks done") + + def test_stack_size(self): + # Various stack size tests. + self.assertEqual(thread.stack_size(), 0, "initial stack size is not 0") + + thread.stack_size(0) + self.assertEqual(thread.stack_size(), 0, "stack_size not reset to default") + + @unittest.skipIf(os.name not in ("nt", "posix"), 'test meant for nt and posix') + def test_nt_and_posix_stack_size(self): + try: + thread.stack_size(4096) + except ValueError: + verbose_print("caught expected ValueError setting " + "stack_size(4096)") + except thread.error: + self.skipTest("platform does not support changing thread stack " + "size") + + fail_msg = "stack_size(%d) failed - should succeed" + for tss in (262144, 0x100000, 0): + thread.stack_size(tss) + self.assertEqual(thread.stack_size(), tss, fail_msg % tss) + verbose_print("successfully set stack_size(%d)" % tss) + + for tss in (262144, 0x100000): + verbose_print("trying stack_size = (%d)" % tss) + self.next_ident = 0 + self.created = 0 + with support.wait_threads_exit(): + for i in range(NUMTASKS): + self.newtask() + + verbose_print("waiting for all tasks to complete") + self.done_mutex.acquire() + verbose_print("all tasks done") + + thread.stack_size(0) + + @unittest.skip("TODO: RUSTPYTHON, weakref destructors") + def test__count(self): + # Test the _count() function. + orig = thread._count() + mut = thread.allocate_lock() + mut.acquire() + started = [] + + def task(): + started.append(None) + mut.acquire() + mut.release() + + with support.wait_threads_exit(): + thread.start_new_thread(task, ()) + while not started: + time.sleep(POLL_SLEEP) + self.assertEqual(thread._count(), orig + 1) + # Allow the task to finish. + mut.release() + # The only reliable way to be sure that the thread ended from the + # interpreter's point of view is to wait for the function object to be + # destroyed. + done = [] + wr = weakref.ref(task, lambda _: done.append(None)) + del task + while not done: + time.sleep(POLL_SLEEP) + self.assertEqual(thread._count(), orig) + + @unittest.skip("TODO: RUSTPYTHON, sys.unraisablehook") + def test_unraisable_exception(self): + def task(): + started.release() + raise ValueError("task failed") + + started = thread.allocate_lock() + with support.catch_unraisable_exception() as cm: + with support.wait_threads_exit(): + started.acquire() + thread.start_new_thread(task, ()) + started.acquire() + + self.assertEqual(str(cm.unraisable.exc_value), "task failed") + self.assertIs(cm.unraisable.object, task) + self.assertEqual(cm.unraisable.err_msg, + "Exception ignored in thread started by") + self.assertIsNotNone(cm.unraisable.exc_traceback) + + +class Barrier: + def __init__(self, num_threads): + self.num_threads = num_threads + self.waiting = 0 + self.checkin_mutex = thread.allocate_lock() + self.checkout_mutex = thread.allocate_lock() + self.checkout_mutex.acquire() + + def enter(self): + self.checkin_mutex.acquire() + self.waiting = self.waiting + 1 + if self.waiting == self.num_threads: + self.waiting = self.num_threads - 1 + self.checkout_mutex.release() + return + self.checkin_mutex.release() + + self.checkout_mutex.acquire() + self.waiting = self.waiting - 1 + if self.waiting == 0: + self.checkin_mutex.release() + return + self.checkout_mutex.release() + + +class BarrierTest(BasicThreadTest): + + def test_barrier(self): + with support.wait_threads_exit(): + self.bar = Barrier(NUMTASKS) + self.running = NUMTASKS + for i in range(NUMTASKS): + thread.start_new_thread(self.task2, (i,)) + verbose_print("waiting for tasks to end") + self.done_mutex.acquire() + verbose_print("tasks done") + + def task2(self, ident): + for i in range(NUMTRIPS): + if ident == 0: + # give it a good chance to enter the next + # barrier before the others are all out + # of the current one + delay = 0 + else: + with self.random_mutex: + delay = random.random() / 10000.0 + verbose_print("task %s will run for %sus" % + (ident, round(delay * 1e6))) + time.sleep(delay) + verbose_print("task %s entering %s" % (ident, i)) + self.bar.enter() + verbose_print("task %s leaving barrier" % ident) + with self.running_mutex: + self.running -= 1 + # Must release mutex before releasing done, else the main thread can + # exit and set mutex to None as part of global teardown; then + # mutex.release() raises AttributeError. + finished = self.running == 0 + if finished: + self.done_mutex.release() + +class LockTests(lock_tests.LockTests): + locktype = thread.allocate_lock + + +class TestForkInThread(unittest.TestCase): + def setUp(self): + self.read_fd, self.write_fd = os.pipe() + + @unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork') + @support.reap_threads + def test_forkinthread(self): + status = "not set" + + def thread1(): + nonlocal status + + # fork in a thread + pid = os.fork() + if pid == 0: + # child + try: + os.close(self.read_fd) + os.write(self.write_fd, b"OK") + finally: + os._exit(0) + else: + # parent + os.close(self.write_fd) + pid, status = os.waitpid(pid, 0) + + with support.wait_threads_exit(): + thread.start_new_thread(thread1, ()) + self.assertEqual(os.read(self.read_fd, 2), b"OK", + "Unable to fork() in thread") + self.assertEqual(status, 0) + + def tearDown(self): + try: + os.close(self.read_fd) + except OSError: + pass + + try: + os.close(self.write_fd) + except OSError: + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index cd509933cb..6e6e33f795 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -390,7 +390,6 @@ def test_float__format__locale(self): self.assertEqual(locale.format_string('%g', x, grouping=True), format(x, 'n')) self.assertEqual(locale.format_string('%.10g', x, grouping=True), format(x, '.10n')) - # TODO: RUSTPYTHON @unittest.expectedFailure @run_with_locale('LC_NUMERIC', 'en_US.UTF8') def test_int__format__locale(self): diff --git a/Lib/threading.py b/Lib/threading.py index 69c8e10eba..bb41456fb1 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -2,7 +2,6 @@ import os as _os import sys as _sys -import _rp_thread # Hack: Trigger populating of RustPython _thread with dummies import _thread from time import monotonic as _time diff --git a/README.md b/README.md index fe1ed44c94..a9eda87f56 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ For this Fork #### Check out our [online demo](https://rustpython.github.io/demo/) running on WebAssembly. -RustPython requires Rust latest stable version (e.g 1.38.0 at Oct 1st 2019). +RustPython requires Rust latest stable version (e.g 1.43.0 at May 24th 2020). To check Rust version: `rustc --version` If you wish to update, `rustup update stable`. diff --git a/compiler/src/compile.rs b/compiler/src/compile.rs index 42c78b2897..53aa36267d 100644 --- a/compiler/src/compile.rs +++ b/compiler/src/compile.rs @@ -1374,7 +1374,7 @@ impl Compiler { for (i, element) in elements.iter().enumerate() { if let ast::ExpressionType::Starred { .. } = &element.node { if seen_star { - return Err(self.error(CompileErrorType::StarArgs)); + return Err(self.error(CompileErrorType::MultipleStarArgs)); } else { seen_star = true; self.emit(Instruction::UnpackEx { @@ -1399,7 +1399,14 @@ impl Compiler { } } } - _ => return Err(self.error(CompileErrorType::Assign(target.name()))), + _ => { + return Err(self.error(match target.node { + ast::ExpressionType::Starred { .. } => CompileErrorType::SyntaxError( + "starred assignment target must be in a list or tuple".to_owned(), + ), + _ => CompileErrorType::Assign(target.name()), + })) + } } Ok(()) @@ -1782,11 +1789,7 @@ impl Compiler { self.compile_comprehension(kind, generators)?; } Starred { .. } => { - return Err( - self.error(CompileErrorType::SyntaxError(std::string::String::from( - "Invalid starred expression", - ))), - ); + return Err(self.error(CompileErrorType::InvalidStarExpr)); } IfExpression { test, body, orelse } => { let no_label = self.new_label(); @@ -2031,21 +2034,33 @@ impl Compiler { } } + let mut compile_element = |element| { + self.compile_expression(element).map_err(|e| { + if matches!(e.error, CompileErrorType::InvalidStarExpr) { + self.error(CompileErrorType::SyntaxError( + "iterable unpacking cannot be used in comprehension".to_owned(), + )) + } else { + e + } + }) + }; + match kind { ast::ComprehensionKind::GeneratorExpression { element } => { - self.compile_expression(element)?; + compile_element(element)?; self.mark_generator(); self.emit(Instruction::YieldValue); self.emit(Instruction::Pop); } ast::ComprehensionKind::List { element } => { - self.compile_expression(element)?; + compile_element(element)?; self.emit(Instruction::ListAppend { i: 1 + generators.len(), }); } ast::ComprehensionKind::Set { element } => { - self.compile_expression(element)?; + compile_element(element)?; self.emit(Instruction::SetAdd { i: 1 + generators.len(), }); diff --git a/compiler/src/error.rs b/compiler/src/error.rs index 85bc9f4625..86635bc951 100644 --- a/compiler/src/error.rs +++ b/compiler/src/error.rs @@ -47,7 +47,9 @@ pub enum CompileErrorType { Parse(ParseErrorType), SyntaxError(String), /// Multiple `*` detected - StarArgs, + MultipleStarArgs, + /// Misplaced `*` expression + InvalidStarExpr, /// Break statement outside of loop. InvalidBreak, /// Continue statement outside of loop. @@ -97,7 +99,10 @@ impl fmt::Display for CompileError { CompileErrorType::ExpectExpr => "Expecting expression, got statement".to_owned(), CompileErrorType::Parse(err) => err.to_string(), CompileErrorType::SyntaxError(err) => err.to_string(), - CompileErrorType::StarArgs => "Two starred expressions in assignment".to_owned(), + CompileErrorType::MultipleStarArgs => { + "two starred expressions in assignment".to_owned() + } + CompileErrorType::InvalidStarExpr => "can't use starred expression here".to_owned(), CompileErrorType::InvalidBreak => "'break' outside loop".to_owned(), CompileErrorType::InvalidContinue => "'continue' outside loop".to_owned(), CompileErrorType::InvalidReturn => "'return' outside function".to_owned(), @@ -120,7 +125,7 @@ impl fmt::Display for CompileError { if self.location.column() > 0 { if let Some(line) = statement.lines().nth(self.location.row() - 1) { // visualize the error, when location and statement are provided - return write!(f, "\n{}\n{}", line, self.location.visualize(&error_desc)); + return write!(f, "{}", self.location.visualize(line, &error_desc)); } } } diff --git a/derive/src/pymodule.rs b/derive/src/pymodule.rs index 1127152139..1b1ee966cf 100644 --- a/derive/src/pymodule.rs +++ b/derive/src/pymodule.rs @@ -1,7 +1,7 @@ use super::Diagnostic; use crate::util::{def_to_name, ItemIdent, ItemMeta}; use proc_macro2::{Span, TokenStream as TokenStream2}; -use quote::{quote, quote_spanned}; +use quote::{quote, quote_spanned, ToTokens}; use std::collections::HashSet; use syn::{parse_quote, spanned::Spanned, Attribute, AttributeArgs, Ident, Item, Meta, NestedMeta}; @@ -156,62 +156,57 @@ fn extract_module_items(mut items: Vec) -> Result Result { - match item { - Item::Mod(mut module) => { - let module_name = def_to_name(&module.ident, "pymodule", attr)?; - - if let Some(content) = module.content.as_mut() { - let items = content - .1 - .iter_mut() - .filter_map(|item| match item { - Item::Fn(syn::ItemFn { attrs, sig, .. }) => Some(ItemIdent { - attrs, - ident: &sig.ident, - }), - Item::Struct(syn::ItemStruct { attrs, ident, .. }) => { - Some(ItemIdent { attrs, ident }) - } - Item::Enum(syn::ItemEnum { attrs, ident, .. }) => { - Some(ItemIdent { attrs, ident }) - } - _ => None, - }) - .collect(); - - let extend_mod = extract_module_items(items)?; - content.1.push(parse_quote! { - const MODULE_NAME: &str = #module_name; - }); - content.1.push(parse_quote! { - pub(crate) fn extend_module( - vm: &::rustpython_vm::vm::VirtualMachine, - module: &::rustpython_vm::pyobject::PyObjectRef, - ) { - #extend_mod - } - }); - content.1.push(parse_quote! { - #[allow(dead_code)] - pub(crate) fn make_module( - vm: &::rustpython_vm::vm::VirtualMachine - ) -> ::rustpython_vm::pyobject::PyObjectRef { - let module = vm.new_module(MODULE_NAME, vm.ctx.new_dict()); - extend_module(vm, &module); - module - } - }); - - Ok(quote! { - #module - }) - } else { - bail_span!( - module, - "#[pymodule] can only be on a module declaration with body" - ) - } - } + let mut module = match item { + Item::Mod(m) => m, other => bail_span!(other, "#[pymodule] can only be on a module declaration"), - } + }; + let module_name = def_to_name(&module.ident, "pymodule", attr)?; + + let (_, content) = match module.content.as_mut() { + Some(c) => c, + None => bail_span!( + module, + "#[pymodule] can only be on a module declaration with body" + ), + }; + + let items = content + .iter_mut() + .filter_map(|item| match item { + Item::Fn(syn::ItemFn { attrs, sig, .. }) => Some(ItemIdent { + attrs, + ident: &sig.ident, + }), + Item::Struct(syn::ItemStruct { attrs, ident, .. }) => Some(ItemIdent { attrs, ident }), + Item::Enum(syn::ItemEnum { attrs, ident, .. }) => Some(ItemIdent { attrs, ident }), + _ => None, + }) + .collect(); + + let extend_mod = extract_module_items(items)?; + content.extend(vec![ + parse_quote! { + const MODULE_NAME: &str = #module_name; + }, + parse_quote! { + pub(crate) fn extend_module( + vm: &::rustpython_vm::vm::VirtualMachine, + module: &::rustpython_vm::pyobject::PyObjectRef, + ) { + #extend_mod + } + }, + parse_quote! { + #[allow(dead_code)] + pub(crate) fn make_module( + vm: &::rustpython_vm::vm::VirtualMachine + ) -> ::rustpython_vm::pyobject::PyObjectRef { + let module = vm.new_module(MODULE_NAME, vm.ctx.new_dict()); + extend_module(vm, &module); + module + } + }, + ]); + + Ok(module.into_token_stream()) } diff --git a/parser/src/error.rs b/parser/src/error.rs index ebe683c8c2..f5935eb3a7 100644 --- a/parser/src/error.rs +++ b/parser/src/error.rs @@ -28,6 +28,7 @@ pub enum LexicalErrorType { UnrecognizedToken { tok: char }, FStringError(FStringErrorType), LineContinuationError, + EOF, OtherError(String), } @@ -59,6 +60,7 @@ impl fmt::Display for LexicalErrorType { LexicalErrorType::LineContinuationError => { write!(f, "unexpected character after line continuation character") } + LexicalErrorType::EOF => write!(f, "unexpected EOF while parsing"), LexicalErrorType::OtherError(msg) => write!(f, "{}", msg), } } diff --git a/parser/src/lexer.rs b/parser/src/lexer.rs index eda96ea741..8a1254fac4 100644 --- a/parser/src/lexer.rs +++ b/parser/src/lexer.rs @@ -1191,6 +1191,13 @@ where location: self.get_pos(), }); } + + if self.chr0.is_none() { + return Err(LexicalError { + error: LexicalErrorType::EOF, + location: self.get_pos(), + }); + } } _ => { diff --git a/parser/src/location.rs b/parser/src/location.rs index 1006d80605..1f543e3e36 100644 --- a/parser/src/location.rs +++ b/parser/src/location.rs @@ -16,13 +16,9 @@ impl fmt::Display for Location { } impl Location { - pub fn visualize(&self, desc: &str) -> String { - format!( - "{}↑\n{}{}", - " ".repeat(self.column - 1), - " ".repeat(self.column - 1), - desc - ) + pub fn visualize(&self, line: &str, desc: &str) -> String { + // desc.to_owned() + format!("{}\n{}\n{}↑", desc, line, " ".repeat(self.column - 1)) } } diff --git a/parser/src/python.lalrpop b/parser/src/python.lalrpop index 548f256e88..c0958a1146 100644 --- a/parser/src/python.lalrpop +++ b/parser/src/python.lalrpop @@ -606,20 +606,8 @@ Path: ast::Expression = { // Decorators: Decorator: ast::Expression = { - "@" "\n" => { - match a { - Some((location, _, arg, _)) => { - ast::Expression { - location, - node: ast::ExpressionType::Call { - function: Box::new(p), - args: arg.args, - keywords: arg.keywords, - } - } - }, - None => p, - } + "@" "\n" => { + p }, }; diff --git a/src/shell.rs b/src/shell.rs index 4a6c095e61..2785c8ae11 100644 --- a/src/shell.rs +++ b/src/shell.rs @@ -1,7 +1,7 @@ mod helper; use rustpython_compiler::{compile, error::CompileError, error::CompileErrorType}; -use rustpython_parser::error::ParseErrorType; +use rustpython_parser::error::{LexicalErrorType, ParseErrorType}; use rustpython_vm::readline::{Readline, ReadlineResult}; use rustpython_vm::{ exceptions::{print_exception, PyBaseExceptionRef}, @@ -23,6 +23,10 @@ fn shell_exec(vm: &VirtualMachine, source: &str, scope: Scope) -> ShellExecResul Ok(_val) => ShellExecResult::Ok, Err(err) => ShellExecResult::PyErr(err), }, + Err(CompileError { + error: CompileErrorType::Parse(ParseErrorType::Lexical(LexicalErrorType::EOF)), + .. + }) => ShellExecResult::Continue, Err(CompileError { error: CompileErrorType::Parse(ParseErrorType::EOF), .. diff --git a/tests/custom_text_test_runner.py b/tests/custom_text_test_runner.py index db2ea4512b..b46157cda9 100644 --- a/tests/custom_text_test_runner.py +++ b/tests/custom_text_test_runner.py @@ -210,6 +210,7 @@ def startTestRun(self): 'num_passed': 0, 'num_failed': 0, 'num_skipped': 0, + 'num_expected_failures': 0, 'execution_time': None} self.suite_number = int(sorted(self.results['suites'].keys())[-1]) + 1 if len(self.results['suites']) else 0 self.case_number = 0 @@ -264,12 +265,14 @@ def startTest(self, test): 'num_passed': 0, 'num_failed': 0, 'num_skipped': 0, + 'num_expected_failures': 0, 'execution_time': None} self.suite_number += 1 self.num_cases = 0 self.num_passed = 0 self.num_failed = 0 self.num_skipped = 0 + self.num_expected_failures = 0 self.results['suites'][self.suite_map[self.suite]]['cases'][self.case_number] = { 'name': self.case, 'method': test._testMethodName, @@ -306,10 +309,12 @@ def stopTest(self, test): self.results['suites'][self.suite_map[self.suite]]['num_passed'] = self.num_passed self.results['suites'][self.suite_map[self.suite]]['num_failed'] = self.num_failed self.results['suites'][self.suite_map[self.suite]]['num_skipped'] = self.num_skipped + self.results['suites'][self.suite_map[self.suite]]['num_expected_failures'] = self.num_expected_failures self.results['suites'][self.suite_map[self.suite]]['cases'][self.current_case_number]['execution_time']= format(self.execution_time, '.%sf' %CustomTextTestResult._execution_time_significant_digits) self.results['num_passed'] += self.num_passed self.results['num_failed'] += self.num_failed self.results['num_skipped'] += self.num_skipped + self.results['num_expected_failures'] += self.num_expected_failures self.case_number += 1 def print_error_string(self, err): @@ -374,7 +379,8 @@ def addExpectedFailure(self, test, err): self.stream.writeln(self.separator_pre_result) self.stream.writeln("EXPECTED FAILURE") self.stream.flush() - self.num_passed += 1 + self.results['suites'][self.suite_map[self.suite]]['cases'][self.current_case_number]['result'] = 'expected_failure' + self.num_expected_failures += 1 self.addScreenshots(test) def addUnexpectedSuccess(self, test): @@ -405,6 +411,7 @@ def printOverallSuiteResults(self, r): 'passed': r[x]['num_passed'], 'failed': r[x]['num_failed'], 'skipped': r[x]['num_skipped'], + 'expected_failures': r[x]['num_expected_failures'], 'percentage': float(r[x]['num_passed'])/(r[x]['num_passed'] + r[x]['num_failed']) * 100 if (r[x]['num_passed'] + r[x]['num_failed']) > 0 else 0, 'time': r[x]['execution_time']}) total_suites_passed = len([x for x in data if not x['failed']]) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 4e60049400..7d65abf162 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -158,3 +158,9 @@ def __eq__(self, other): assert Complex(4, 5) - 3 == Complex(1, 5) assert 7 - Complex(4, 5) == Complex(3, -5) + +assert complex("5+2j") == 5 + 2j +assert complex("5-2j") == 5 - 2j +assert complex("-2j") == -2j +assert_raises(TypeError, lambda: complex("5+2j", 1)) +assert_raises(ValueError, lambda: complex("abc")) diff --git a/tests/snippets/dict_union.py b/tests/snippets/dict_union.py new file mode 100644 index 0000000000..29e0718d45 --- /dev/null +++ b/tests/snippets/dict_union.py @@ -0,0 +1,83 @@ + +import testutils + +def test_dunion_ior0(): + a={1:2,2:3} + b={3:4,5:6} + a|=b + + assert a == {1:2,2:3,3:4,5:6}, f"wrong value assigned {a=}" + assert b == {3:4,5:6}, f"right hand side modified, {b=}" + +def test_dunion_or0(): + a={1:2,2:3} + b={3:4,5:6} + c=a|b + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_or1(): + a={1:2,2:3} + b={3:4,5:6} + c=a.__or__(b) + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_ror0(): + a={1:2,2:3} + b={3:4,5:6} + c=b.__ror__(a) + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_other_types(): + def perf_test_or(other_obj): + d={1:2} + try: + d.__or__(other_obj) + except: + return True + return False + + def perf_test_ior(other_obj): + d={1:2} + try: + d.__ior__(other_obj) + except: + return True + return False + + def perf_test_ror(other_obj): + d={1:2} + try: + d.__ror__(other_obj) + except: + return True + return False + + test_fct={'__or__':perf_test_or, '__ror__':perf_test_ror, '__ior__':perf_test_ior} + others=['FooBar', 42, [36], set([19]), ['aa'], None] + for tfn,tf in test_fct.items(): + for other in others: + assert tf(other), f"Failed: dict {tfn}, accepted {other}" + + + + +testutils.skip_if_unsupported(3,9,test_dunion_ior0) +testutils.skip_if_unsupported(3,9,test_dunion_or0) +testutils.skip_if_unsupported(3,9,test_dunion_or1) +testutils.skip_if_unsupported(3,9,test_dunion_ror0) +testutils.skip_if_unsupported(3,9,test_dunion_other_types) + + + diff --git a/tests/snippets/floats.py b/tests/snippets/floats.py index def52fedb8..ce39be708b 100644 --- a/tests/snippets/floats.py +++ b/tests/snippets/floats.py @@ -490,6 +490,8 @@ def identical(x, y): assert float('0_0') == 0.0 assert float('.0') == 0.0 assert float('0.') == 0.0 +assert float('-.0') == 0.0 +assert float('+.0') == 0.0 assert_raises(ValueError, lambda: float('0._0')) assert_raises(ValueError, lambda: float('0_.0')) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 7ceb653c1c..6c40903d44 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -475,7 +475,7 @@ def try_mutate_str(): # remove*fix test def test_removeprefix(): - s='foobarfoo' + s = 'foobarfoo' s_ref='foobarfoo' assert s.removeprefix('f') == s_ref[1:] assert s.removeprefix('fo') == s_ref[2:] @@ -488,9 +488,24 @@ def test_removeprefix(): assert s.removeprefix('-foo') == s_ref assert s.removeprefix('afoo') == s_ref assert s.removeprefix('*foo') == s_ref - + assert s==s_ref, 'undefined test fail' + s_uc = '😱foobarfoo🖖' + s_ref_uc = '😱foobarfoo🖖' + assert s_uc.removeprefix('😱') == s_ref_uc[1:] + assert s_uc.removeprefix('😱fo') == s_ref_uc[3:] + assert s_uc.removeprefix('😱foo') == s_ref_uc[4:] + + assert s_uc.removeprefix('🖖') == s_ref_uc + assert s_uc.removeprefix('foo') == s_ref_uc + assert s_uc.removeprefix(' ') == s_ref_uc + assert s_uc.removeprefix('_😱') == s_ref_uc + assert s_uc.removeprefix(' 😱') == s_ref_uc + assert s_uc.removeprefix('-😱') == s_ref_uc + assert s_uc.removeprefix('#😱') == s_ref_uc + + def test_removeprefix_types(): s='0123456' s_ref='0123456' @@ -501,7 +516,7 @@ def test_removeprefix_types(): s.removeprefix(o) except: found=True - + assert found, f'Removeprefix accepts other type: {type(o)}: {o=}' def test_removesuffix(): @@ -518,9 +533,23 @@ def test_removesuffix(): assert s.removesuffix('foo-') == s_ref assert s.removesuffix('foo*') == s_ref assert s.removesuffix('fooa') == s_ref - assert s==s_ref, 'undefined test fail' + s_uc = '😱foobarfoo🖖' + s_ref_uc = '😱foobarfoo🖖' + assert s_uc.removesuffix('🖖') == s_ref_uc[:-1] + assert s_uc.removesuffix('oo🖖') == s_ref_uc[:-3] + assert s_uc.removesuffix('foo🖖') == s_ref_uc[:-4] + + assert s_uc.removesuffix('😱') == s_ref_uc + assert s_uc.removesuffix('foo') == s_ref_uc + assert s_uc.removesuffix(' ') == s_ref_uc + assert s_uc.removesuffix('🖖_') == s_ref_uc + assert s_uc.removesuffix('🖖 ') == s_ref_uc + assert s_uc.removesuffix('🖖-') == s_ref_uc + assert s_uc.removesuffix('🖖#') == s_ref_uc + + def test_removesuffix_types(): s='0123456' s_ref='0123456' @@ -531,13 +560,10 @@ def test_removesuffix_types(): s.removesuffix(o) except: found=True - assert found, f'Removesuffix accepts other type: {type(o)}: {o=}' - skip_if_unsupported(3,9,test_removeprefix) skip_if_unsupported(3,9,test_removeprefix_types) skip_if_unsupported(3,9,test_removesuffix) skip_if_unsupported(3,9,test_removesuffix_types) - diff --git a/tests/snippets/test_threading.py b/tests/snippets/test_threading.py index e00d1835fd..41024b360e 100644 --- a/tests/snippets/test_threading.py +++ b/tests/snippets/test_threading.py @@ -5,17 +5,20 @@ def thread_function(name): - output.append("Thread %s: starting" % name) + output.append((name, 0)) time.sleep(2.0) - output.append("Thread %s: finishing" % name) + output.append((name, 1)) -output.append("Main : before creating thread") +output.append((0, 0)) x = threading.Thread(target=thread_function, args=(1, )) -output.append("Main : before running thread") +output.append((0, 1)) x.start() -output.append("Main : wait for the thread to finish") +output.append((0, 2)) x.join() -output.append("Main : all done") +output.append((0, 3)) assert len(output) == 6, output +# CPython has [(1, 0), (0, 2)] for the middle 2, but we have [(0, 2), (1, 0)] +# TODO: maybe fix this, if it turns out to be a problem? +# assert output == [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (0, 3)] diff --git a/tests/snippets/testutils.py b/tests/snippets/testutils.py index c779d2c898..f1eb4d9cdf 100644 --- a/tests/snippets/testutils.py +++ b/tests/snippets/testutils.py @@ -91,5 +91,4 @@ def exec(): elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers: exec() else: - assert False, f'Test cannot performed on this python version. {platform.python_implementation()} {platform.python_version()}' - + assert False, f'Test cannot performed on this python version. {platform.python_implementation()} {platform.python_version()}' \ No newline at end of file diff --git a/tests/snippets/tuple.py b/tests/snippets/tuple.py index 0f4306fa61..fd59e90609 100644 --- a/tests/snippets/tuple.py +++ b/tests/snippets/tuple.py @@ -44,7 +44,7 @@ def __eq__(self, x): b = (55, *a) assert b == (55, 1, 2, 3, 1) -assert () is () +assert () is () # noqa a = () b = () diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 71c179e3e0..28b0cc435d 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -70,6 +70,8 @@ num_enum = "0.4" smallbox = "0.8" bstr = "0.2.12" crossbeam-utils = "0.7" +generational-arena = "0.2" +parking_lot = { git = "https://github.com/Amanieu/parking_lot" } # TODO: use published version ## unicode stuff unicode_names2 = "0.4" diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 1f5f08bd56..c6838616a8 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -736,7 +736,7 @@ fn builtin_sum(iterable: PyIterable, start: OptionalArg, vm: &VirtualMachine) -> // Should be renamed to builtin___import__? fn builtin_import(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - vm.invoke(&vm.import_func.borrow(), args) + vm.invoke(&vm.import_func, args) } fn builtin_vars(obj: OptionalArg, vm: &VirtualMachine) -> PyResult { @@ -761,7 +761,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) { }); } - let debug_mode: bool = vm.settings.optimize == 0; + let debug_mode: bool = vm.state.settings.optimize == 0; extend_module!(vm, module, { "__debug__" => ctx.new_bool(debug_mode), //set __name__ fixes: https://github.com/RustPython/RustPython/issues/146 diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 43a9fceb04..41576c8a5b 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -1,6 +1,6 @@ use crate::obj::objstr::PyString; use crate::pyhash; -use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult, ThreadSafe}; +use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult}; use crate::vm::VirtualMachine; use num_bigint::ToBigInt; /// Ordered dictionary implementation. @@ -23,8 +23,6 @@ pub struct Dict { inner: RwLock>, } -impl ThreadSafe for Dict {} - struct InnerDict { size: usize, indices: HashMap, diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 05d1a370fa..bf1fa0b8aa 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -6,8 +6,8 @@ use crate::obj::objtuple::{PyTuple, PyTupleRef}; use crate::obj::objtype::{self, PyClass, PyClassRef}; use crate::py_serde; use crate::pyobject::{ - PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryFromObject, TypeProtocol, + PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::slots::PyTpFlags; use crate::types::create_type; @@ -30,8 +30,6 @@ pub struct PyBaseException { args: RwLock, } -impl ThreadSafe for PyBaseException {} - impl fmt::Debug for PyBaseException { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // TODO: implement more detailed, non-recursive Debug formatter @@ -645,6 +643,10 @@ pub fn init(ctx: &PyContext) { "text" => ctx.none(), }); + extend_class!(ctx, &excs.system_exit, { + "code" => ctx.new_readonly_getset("code", make_arg_getter(0)), + }); + extend_class!(ctx, &excs.import_error, { "__init__" => ctx.new_method(import_error_init), "msg" => ctx.new_readonly_getset("msg", make_arg_getter(0)), diff --git a/vm/src/frame.rs b/vm/src/frame.rs index e522a8b4f3..0b4416fdb8 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -940,7 +940,9 @@ impl ExecutingFrame<'_> { if unpack { for obj in self.pop_multiple(size) { // Take all key-value pairs from the dict: - let dict: PyDictRef = obj.downcast().expect("Need a dictionary to build a map."); + let dict: PyDictRef = obj.downcast().map_err(|obj| { + vm.new_type_error(format!("'{}' object is not a mapping", obj.class().name)) + })?; for (key, value) in dict { if for_call { if map_obj.contains_key(&key, vm) { @@ -1014,6 +1016,7 @@ impl ExecutingFrame<'_> { let kwargs = if *has_kwargs { let kw_dict: PyDictRef = match self.pop_value().downcast() { Err(_) => { + // TODO: check collections.abc.Mapping return Err(vm.new_type_error("Kwargs must be a dict.".to_owned())); } Ok(x) => x, @@ -1137,7 +1140,7 @@ impl ExecutingFrame<'_> { let min_expected = before + after; if elements.len() < min_expected { Err(vm.new_value_error(format!( - "Not enough values to unpack (expected at least {}, got {}", + "not enough values to unpack (expected at least {}, got {})", min_expected, elements.len() ))) diff --git a/vm/src/function.rs b/vm/src/function.rs index 48cceb8f1d..67e2dc0455 100644 --- a/vm/src/function.rs +++ b/vm/src/function.rs @@ -269,6 +269,11 @@ impl KwArgs { self.0.remove(name) } } +impl From> for KwArgs { + fn from(map: HashMap) -> Self { + KwArgs(map) + } +} impl FromArgs for KwArgs where diff --git a/vm/src/import.rs b/vm/src/import.rs index bf07f83623..d865c20602 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -14,14 +14,13 @@ use crate::vm::{InitParameter, VirtualMachine}; #[cfg(feature = "rustpython-compiler")] use rustpython_compiler::compile; -pub fn init_importlib(vm: &VirtualMachine, initialize_parameter: InitParameter) -> PyResult { +pub fn init_importlib(vm: &mut VirtualMachine, initialize_parameter: InitParameter) -> PyResult { flame_guard!("init importlib"); let importlib = import_frozen(vm, "_frozen_importlib")?; let impmod = import_builtin(vm, "_imp")?; let install = vm.get_attribute(importlib.clone(), "_install")?; vm.invoke(&install, vec![vm.sys_module.clone(), impmod])?; - vm.import_func - .replace(vm.get_attribute(importlib.clone(), "__import__")?); + vm.import_func = vm.get_attribute(importlib.clone(), "__import__")?; match initialize_parameter { InitParameter::InitializeExternal if cfg!(feature = "rustpython-compiler") => { @@ -58,16 +57,16 @@ pub fn init_importlib(vm: &VirtualMachine, initialize_parameter: InitParameter) } pub fn import_frozen(vm: &VirtualMachine, module_name: &str) -> PyResult { - vm.frozen - .borrow() + vm.state + .frozen .get(module_name) .ok_or_else(|| vm.new_import_error(format!("Cannot import frozen module {}", module_name))) .and_then(|frozen| import_codeobj(vm, module_name, frozen.code.clone(), false)) } pub fn import_builtin(vm: &VirtualMachine, module_name: &str) -> PyResult { - vm.stdlib_inits - .borrow() + vm.state + .stdlib_inits .get(module_name) .ok_or_else(|| vm.new_import_error(format!("Cannot import bultin module {}", module_name))) .and_then(|make_module_func| { diff --git a/vm/src/obj/objasyncgenerator.rs b/vm/src/obj/objasyncgenerator.rs index c26e2a7c0e..9a377adaf3 100644 --- a/vm/src/obj/objasyncgenerator.rs +++ b/vm/src/obj/objasyncgenerator.rs @@ -4,7 +4,7 @@ use super::objtype::{self, PyClassRef}; use crate::exceptions::PyBaseExceptionRef; use crate::frame::FrameRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use crossbeam_utils::atomic::AtomicCell; @@ -16,7 +16,6 @@ pub struct PyAsyncGen { running_async: AtomicCell, } pub type PyAsyncGenRef = PyRef; -impl ThreadSafe for PyAsyncGen {} impl PyValue for PyAsyncGen { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -164,8 +163,6 @@ struct PyAsyncGenASend { value: PyObjectRef, } -impl ThreadSafe for PyAsyncGenASend {} - impl PyValue for PyAsyncGenASend { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.types.async_generator_asend.clone() @@ -262,8 +259,6 @@ struct PyAsyncGenAThrow { value: (PyObjectRef, PyObjectRef, PyObjectRef), } -impl ThreadSafe for PyAsyncGenAThrow {} - impl PyValue for PyAsyncGenAThrow { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.types.async_generator_athrow.clone() diff --git a/vm/src/obj/objbuiltinfunc.rs b/vm/src/obj/objbuiltinfunc.rs index 9a459efcc7..9b91ce0681 100644 --- a/vm/src/obj/objbuiltinfunc.rs +++ b/vm/src/obj/objbuiltinfunc.rs @@ -3,7 +3,7 @@ use std::fmt; use crate::function::{OptionalArg, PyFuncArgs, PyNativeFunc}; use crate::obj::objtype::PyClassRef; use crate::pyobject::{ - IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyResult, PyValue, ThreadSafe, TypeProtocol, + IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyResult, PyValue, TypeProtocol, }; use crate::slots::{SlotCall, SlotDescriptor}; use crate::vm::VirtualMachine; @@ -12,7 +12,6 @@ use crate::vm::VirtualMachine; pub struct PyBuiltinFunction { value: PyNativeFunc, } -impl ThreadSafe for PyBuiltinFunction {} impl PyValue for PyBuiltinFunction { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -49,7 +48,6 @@ impl PyBuiltinFunction {} pub struct PyBuiltinMethod { function: PyBuiltinFunction, } -impl ThreadSafe for PyBuiltinMethod {} impl PyValue for PyBuiltinMethod { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 029be2d1c9..3cdf7d6344 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -21,7 +21,7 @@ use crate::function::{OptionalArg, OptionalOption}; use crate::obj::objstr::do_cformat_string; use crate::pyobject::{ Either, PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, - PyValue, ThreadSafe, TryFromObject, TypeProtocol, + PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -42,8 +42,6 @@ pub struct PyByteArray { inner: RwLock, } -impl ThreadSafe for PyByteArray {} - pub type PyByteArrayRef = PyRef; impl PyByteArray { @@ -389,6 +387,30 @@ impl PyByteArray { self.borrow_value().rstrip(chars).into() } + /// removeprefix($self, prefix, /) + /// + /// + /// Return a bytearray object with the given prefix string removed if present. + /// + /// If the bytearray starts with the prefix string, return string[len(prefix):] + /// Otherwise, return a copy of the original bytearray. + #[pymethod(name = "removeprefix")] + fn removeprefix(&self, prefix: PyByteInner) -> PyByteArray { + self.borrow_value().removeprefix(prefix).into() + } + + /// removesuffix(self, prefix, /) + /// + /// + /// Return a bytearray object with the given suffix string removed if present. + /// + /// If the bytearray ends with the suffix string, return string[:len(suffix)] + /// Otherwise, return a copy of the original bytearray. + #[pymethod(name = "removesuffix")] + fn removesuffix(&self, suffix: PyByteInner) -> PyByteArray { + self.borrow_value().removesuffix(suffix).to_vec().into() + } + #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { self.borrow_value() diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 38f6f0b718..acb8889dfe 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -17,8 +17,7 @@ use super::pystr::{self, PyCommonString, PyCommonStringWrapper}; use crate::function::{OptionalArg, OptionalOption}; use crate::pyhash; use crate::pyobject::{ - Either, PyComparisonValue, PyIterable, PyObjectRef, PyResult, ThreadSafe, TryFromObject, - TypeProtocol, + Either, PyComparisonValue, PyIterable, PyObjectRef, PyResult, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -33,8 +32,6 @@ impl From> for PyByteInner { } } -impl ThreadSafe for PyByteInner {} - impl TryFromObject for PyByteInner { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { match_class!(match obj { @@ -843,6 +840,24 @@ impl PyByteInner { .to_vec() } + // new in Python 3.9 + pub fn removeprefix(&self, prefix: PyByteInner) -> Vec { + self.elements + .py_removeprefix(&prefix.elements, prefix.elements.len(), |s, p| { + s.starts_with(p) + }) + .to_vec() + } + + // new in Python 3.9 + pub fn removesuffix(&self, suffix: PyByteInner) -> Vec { + self.elements + .py_removesuffix(&suffix.elements, suffix.elements.len(), |s, p| { + s.ends_with(p) + }) + .to_vec() + } + pub fn split( &self, options: ByteInnerSplitOptions, diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index bf76f49a2a..6e6882d1d4 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -22,7 +22,7 @@ use crate::pyobject::{ Either, IntoPyObject, PyArithmaticValue::{self, *}, PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, - ThreadSafe, TryFromObject, TypeProtocol, + TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -41,8 +41,6 @@ pub struct PyBytes { inner: PyByteInner, } -impl ThreadSafe for PyBytes {} - pub type PyBytesRef = PyRef; impl PyBytes { @@ -347,6 +345,30 @@ impl PyBytes { self.inner.rstrip(chars).into() } + /// removeprefix($self, prefix, /) + /// + /// + /// Return a bytes object with the given prefix string removed if present. + /// + /// If the bytes starts with the prefix string, return string[len(prefix):] + /// Otherwise, return a copy of the original bytes. + #[pymethod(name = "removeprefix")] + fn removeprefix(&self, prefix: PyByteInner) -> PyBytes { + self.inner.removeprefix(prefix).into() + } + + /// removesuffix(self, prefix, /) + /// + /// + /// Return a bytes object with the given suffix string removed if present. + /// + /// If the bytes ends with the suffix string, return string[:len(suffix)] + /// Otherwise, return a copy of the original bytes. + #[pymethod(name = "removesuffix")] + fn removesuffix(&self, suffix: PyByteInner) -> PyBytes { + self.inner.removesuffix(suffix).into() + } + #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { self.inner diff --git a/vm/src/obj/objclassmethod.rs b/vm/src/obj/objclassmethod.rs index 84af110048..edc579698c 100644 --- a/vm/src/obj/objclassmethod.rs +++ b/vm/src/obj/objclassmethod.rs @@ -1,7 +1,7 @@ use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyobject::{ - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TypeProtocol, + PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::slots::SlotDescriptor; use crate::vm::VirtualMachine; @@ -32,7 +32,6 @@ pub struct PyClassMethod { callable: PyObjectRef, } pub type PyClassMethodRef = PyRef; -impl ThreadSafe for PyClassMethod {} impl PyClassMethod { pub fn new(value: PyObjectRef) -> Self { diff --git a/vm/src/obj/objcode.rs b/vm/src/obj/objcode.rs index 5f4a36bf9f..0bf11c679d 100644 --- a/vm/src/obj/objcode.rs +++ b/vm/src/obj/objcode.rs @@ -7,9 +7,7 @@ use std::ops::Deref; use super::objtype::PyClassRef; use crate::bytecode; -use crate::pyobject::{ - IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, -}; +use crate::pyobject::{IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; pub type PyCodeRef = PyRef; @@ -19,8 +17,6 @@ pub struct PyCode { pub code: bytecode::CodeObject, } -impl ThreadSafe for PyCode {} - impl Deref for PyCode { type Target = bytecode::CodeObject; fn deref(&self) -> &Self::Target { diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 6a05ebc776..4b7f04de0c 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -1,13 +1,16 @@ use num_complex::Complex64; use num_traits::Zero; use std::num::Wrapping; +use std::str::FromStr; -use super::objfloat::{self, IntoPyFloat}; +use super::objfloat::{self, IntoPyFloat, PyFloat}; +use super::objint::{self, PyInt}; +use super::objstr::PyString; use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyhash; use crate::pyobject::{ - IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -20,8 +23,6 @@ pub struct PyComplex { value: Complex64, } -impl ThreadSafe for PyComplex {} - type PyComplexRef = PyRef; impl PyValue for PyComplex { @@ -224,13 +225,36 @@ impl PyComplex { #[pyslot] fn tp_new( cls: PyClassRef, - real: OptionalArg, + real: OptionalArg, imag: OptionalArg, vm: &VirtualMachine, ) -> PyResult { let real = match real { OptionalArg::Missing => 0.0, - OptionalArg::Present(ref value) => value.to_f64(), + OptionalArg::Present(obj) => match_class!(match obj { + i @ PyInt => { + objint::try_float(i.as_bigint(), vm)? + } + f @ PyFloat => { + f.to_f64() + } + s @ PyString => { + if imag.into_option().is_some() { + return Err(vm.new_type_error( + "complex() can't take second arg if first is a string".to_owned(), + )); + } + let value = Complex64::from_str(s.as_str()) + .map_err(|err| vm.new_value_error(err.to_string()))?; + return PyComplex { value }.into_ref_with_type(vm, cls); + } + obj => { + return Err(vm.new_type_error(format!( + "complex() first argument must be a string or a number, not '{}'", + obj.class() + ))); + } + }), }; let imag = match imag { diff --git a/vm/src/obj/objcoroinner.rs b/vm/src/obj/objcoroinner.rs index 4cbb0940b4..33b1c2f2af 100644 --- a/vm/src/obj/objcoroinner.rs +++ b/vm/src/obj/objcoroinner.rs @@ -1,7 +1,7 @@ use super::objtype::{self, PyClassRef}; use crate::exceptions::{self, PyBaseExceptionRef}; use crate::frame::{ExecutionResult, FrameRef}; -use crate::pyobject::{PyObjectRef, PyResult, ThreadSafe}; +use crate::pyobject::{PyObjectRef, PyResult}; use crate::vm::VirtualMachine; use crossbeam_utils::atomic::AtomicCell; @@ -42,8 +42,6 @@ pub struct Coro { variant: Variant, } -impl ThreadSafe for Coro {} - impl Coro { pub fn new(frame: FrameRef, variant: Variant) -> Self { Coro { diff --git a/vm/src/obj/objcoroutine.rs b/vm/src/obj/objcoroutine.rs index 48ccfe4dc1..d1faacc2be 100644 --- a/vm/src/obj/objcoroutine.rs +++ b/vm/src/obj/objcoroutine.rs @@ -4,7 +4,7 @@ use super::objstr::PyStringRef; use super::objtype::PyClassRef; use crate::frame::FrameRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; pub type PyCoroutineRef = PyRef; @@ -15,8 +15,6 @@ pub struct PyCoroutine { inner: Coro, } -impl ThreadSafe for PyCoroutine {} - impl PyValue for PyCoroutine { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.types.coroutine_type.clone() @@ -103,8 +101,6 @@ pub struct PyCoroutineWrapper { coro: PyCoroutineRef, } -impl ThreadSafe for PyCoroutineWrapper {} - impl PyValue for PyCoroutineWrapper { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.types.coroutine_wrapper_type.clone() diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index e2cdb108db..c9b088b03d 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -1,6 +1,7 @@ -use std::cell::Cell; use std::fmt; +use crossbeam_utils::atomic::AtomicCell; + use super::objiter; use super::objstr; use super::objtype::{self, PyClassRef}; @@ -9,7 +10,7 @@ use crate::exceptions::PyBaseExceptionRef; use crate::function::{KwArgs, OptionalArg, PyFuncArgs}; use crate::pyobject::{ IdProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyClassImpl, PyContext, PyIterable, - PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, + PyObjectRef, PyRef, PyResult, PyValue, }; use crate::vm::{ReprGuard, VirtualMachine}; @@ -23,7 +24,6 @@ pub struct PyDict { entries: DictContentType, } pub type PyDictRef = PyRef; -impl ThreadSafe for PyDict {} impl fmt::Debug for PyDict { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -585,14 +585,14 @@ macro_rules! dict_iterator { struct $iter_name { pub dict: PyDictRef, pub size: dictdatatype::DictSize, - pub position: Cell, + pub position: AtomicCell, } #[pyimpl] impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { - position: Cell::new(0), + position: AtomicCell::new(0), size: dict.size(), dict, } @@ -601,15 +601,15 @@ macro_rules! dict_iterator { #[pymethod(name = "__next__")] #[allow(clippy::redundant_closure_call)] fn next(&self, vm: &VirtualMachine) -> PyResult { - let mut position = self.position.get(); if self.dict.entries.has_changed_size(&self.size) { return Err( vm.new_runtime_error("dictionary changed size during iteration".to_owned()) ); } + let mut position = self.position.load(); match self.dict.entries.next_entry(&mut position) { Some((key, value)) => { - self.position.set(position); + self.position.store(position); Ok($result_fn(vm, key, value)) } None => Err(objiter::new_stop_iteration(vm)), @@ -623,7 +623,7 @@ macro_rules! dict_iterator { #[pymethod(name = "__length_hint__")] fn length_hint(&self) -> usize { - self.dict.entries.len_from_entry_index(self.position.get()) + self.dict.entries.len_from_entry_index(self.position.load()) } } diff --git a/vm/src/obj/objenumerate.rs b/vm/src/obj/objenumerate.rs index 3c6c864c2e..6966ae7fe6 100644 --- a/vm/src/obj/objenumerate.rs +++ b/vm/src/obj/objenumerate.rs @@ -8,7 +8,7 @@ use super::objint::PyIntRef; use super::objiter; use super::objtype::PyClassRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[pyclass] @@ -18,7 +18,6 @@ pub struct PyEnumerate { iterator: PyObjectRef, } type PyEnumerateRef = PyRef; -impl ThreadSafe for PyEnumerate {} impl PyValue for PyEnumerate { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objfilter.rs b/vm/src/obj/objfilter.rs index f5f0e4172a..1b3f2ae654 100644 --- a/vm/src/obj/objfilter.rs +++ b/vm/src/obj/objfilter.rs @@ -1,9 +1,7 @@ use super::objbool; use super::objiter; use super::objtype::PyClassRef; -use crate::pyobject::{ - IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, -}; +use crate::pyobject::{IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; pub type PyFilterRef = PyRef; @@ -18,7 +16,6 @@ pub struct PyFilter { predicate: PyObjectRef, iterator: PyObjectRef, } -impl ThreadSafe for PyFilter {} impl PyValue for PyFilter { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 0ffeba076a..bc7f4dff28 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -12,7 +12,7 @@ use crate::function::{OptionalArg, OptionalOption}; use crate::pyhash; use crate::pyobject::{ IntoPyObject, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, - PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, TypeProtocol, + PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -23,8 +23,6 @@ pub struct PyFloat { value: f64, } -impl ThreadSafe for PyFloat {} - impl PyFloat { pub fn to_f64(self) -> f64 { self.value @@ -612,7 +610,7 @@ fn str_to_float(vm: &VirtualMachine, literal: &str) -> PyResult { if !c.is_ascii_alphanumeric() { if let Some(l) = last_tok { - if !l.is_ascii_alphanumeric() { + if !l.is_ascii_alphanumeric() && !(c == '.' && (l == '-' || l == '+')) { return Err(invalid_convert(vm, literal)); } } diff --git a/vm/src/obj/objfunction.rs b/vm/src/obj/objfunction.rs index c6d24fe6f1..305fde3674 100644 --- a/vm/src/obj/objfunction.rs +++ b/vm/src/obj/objfunction.rs @@ -11,7 +11,7 @@ use crate::obj::objcoroutine::PyCoroutine; use crate::obj::objgenerator::PyGenerator; use crate::pyobject::{ IdProtocol, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, - ThreadSafe, TypeProtocol, + TypeProtocol, }; use crate::scope::Scope; use crate::slots::{SlotCall, SlotDescriptor}; @@ -28,7 +28,6 @@ pub struct PyFunction { defaults: Option, kw_only_defaults: Option, } -impl ThreadSafe for PyFunction {} impl SlotDescriptor for PyFunction { fn descr_get( @@ -294,8 +293,6 @@ pub struct PyBoundMethod { pub function: PyObjectRef, } -impl ThreadSafe for PyBoundMethod {} - impl SlotCall for PyBoundMethod { fn call(&self, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { let args = args.insert(self.object.clone()); diff --git a/vm/src/obj/objgenerator.rs b/vm/src/obj/objgenerator.rs index b2044621f9..e91e24406c 100644 --- a/vm/src/obj/objgenerator.rs +++ b/vm/src/obj/objgenerator.rs @@ -7,7 +7,7 @@ use super::objcoroinner::{Coro, Variant}; use super::objtype::PyClassRef; use crate::frame::FrameRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; pub type PyGeneratorRef = PyRef; @@ -18,8 +18,6 @@ pub struct PyGenerator { inner: Coro, } -impl ThreadSafe for PyGenerator {} - impl PyValue for PyGenerator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.generator_type() diff --git a/vm/src/obj/objgetset.rs b/vm/src/obj/objgetset.rs index 83f69110e7..00f0f25ae7 100644 --- a/vm/src/obj/objgetset.rs +++ b/vm/src/obj/objgetset.rs @@ -4,8 +4,8 @@ use super::objtype::PyClassRef; use crate::function::{FunctionBox, OptionalArg, OwnedParam, RefParam}; use crate::pyobject::{ - IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryFromObject, TypeProtocol, + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::slots::SlotDescriptor; use crate::vm::VirtualMachine; @@ -152,8 +152,6 @@ pub struct PyGetSet { // doc: Option, } -impl ThreadSafe for PyGetSet {} - impl std::fmt::Debug for PyGetSet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 746df5c7b4..271652b5c1 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -18,7 +18,7 @@ use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyhash; use crate::pyobject::{ IdProtocol, IntoPyObject, PyArithmaticValue, PyClassImpl, PyComparisonValue, PyContext, - PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, TypeProtocol, + PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::stdlib::array::PyArray; use crate::vm::VirtualMachine; @@ -43,8 +43,6 @@ pub struct PyInt { value: BigInt, } -impl ThreadSafe for PyInt {} - impl fmt::Display for PyInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { BigInt::fmt(&self.value, f) @@ -655,6 +653,13 @@ impl PyInt { fn denominator(&self) -> usize { 1 } + + #[pymethod] + /// Returns the number of ones 1 an int. When the number is < 0, + /// then it returns the number of ones of the absolute value. + fn bit_count(&self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bigint(&BigInt::from(self.value.to_u32_digits().1.iter().map(|n|n.count_ones()).sum::()))) + } } #[derive(FromArgs)] diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 654c85467a..675df6f33a 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -10,8 +10,8 @@ use super::objsequence; use super::objtype::{self, PyClassRef}; use crate::exceptions::PyBaseExceptionRef; use crate::pyobject::{ - PyCallable, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryFromObject, TypeProtocol, + PyCallable, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -26,7 +26,7 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: &PyObjectRef) -> PyResult { vm.invoke(&method, vec![]) } else { vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || { - format!("Cannot iterate over {}", iter_target.class().name) + format!("'{}' object is not iterable", iter_target.class().name) })?; Ok(PySequenceIterator::new_forward(iter_target.clone()) .into_ref(vm) @@ -145,7 +145,6 @@ pub struct PySequenceIterator { pub obj: PyObjectRef, pub reversed: bool, } -impl ThreadSafe for PySequenceIterator {} impl PyValue for PySequenceIterator { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -219,7 +218,6 @@ pub struct PyCallableIterator { sentinel: PyObjectRef, done: AtomicCell, } -impl ThreadSafe for PyCallableIterator {} impl PyValue for PyCallableIterator { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 1f3169cc8c..907de91d64 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -17,7 +17,7 @@ use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyobject::{ IdProtocol, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyContext, PyIterable, - PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, TypeProtocol, + PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::sequence::{self, SimpleSeq}; use crate::vm::{ReprGuard, VirtualMachine}; @@ -32,8 +32,6 @@ pub struct PyList { elements: RwLock>, } -impl ThreadSafe for PyList {} - impl fmt::Debug for PyList { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // TODO: implement more detailed, non-recursive Debug formatter @@ -876,8 +874,6 @@ pub struct PyListIterator { pub list: PyListRef, } -impl ThreadSafe for PyListIterator {} - impl PyValue for PyListIterator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.listiterator_type() @@ -917,8 +913,6 @@ pub struct PyListReverseIterator { pub list: PyListRef, } -impl ThreadSafe for PyListReverseIterator {} - impl PyValue for PyListReverseIterator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.listreverseiterator_type() diff --git a/vm/src/obj/objmap.rs b/vm/src/obj/objmap.rs index 9b6e25adaf..a90b1c025d 100644 --- a/vm/src/obj/objmap.rs +++ b/vm/src/obj/objmap.rs @@ -1,7 +1,7 @@ use super::objiter; use super::objtype::PyClassRef; use crate::function::Args; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; /// map(func, *iterables) --> map object @@ -15,7 +15,6 @@ pub struct PyMap { iterators: Vec, } type PyMapRef = PyRef; -impl ThreadSafe for PyMap {} impl PyValue for PyMap { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objmappingproxy.rs b/vm/src/obj/objmappingproxy.rs index 47b447a2da..e2b068162b 100644 --- a/vm/src/obj/objmappingproxy.rs +++ b/vm/src/obj/objmappingproxy.rs @@ -4,8 +4,7 @@ use super::objstr::PyStringRef; use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyobject::{ - ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryFromObject, + ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, }; use crate::vm::VirtualMachine; @@ -14,7 +13,6 @@ use crate::vm::VirtualMachine; pub struct PyMappingProxy { mapping: MappingProxyInner, } -impl ThreadSafe for PyMappingProxy {} #[derive(Debug)] enum MappingProxyInner { diff --git a/vm/src/obj/objmemory.rs b/vm/src/obj/objmemory.rs index 4add6e329f..41ccd6c893 100644 --- a/vm/src/obj/objmemory.rs +++ b/vm/src/obj/objmemory.rs @@ -1,7 +1,7 @@ use super::objbyteinner::try_as_byte; use super::objtype::{issubclass, PyClassRef}; use crate::pyobject::{ - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TypeProtocol, + PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::stdlib::array::PyArray; use crate::vm::VirtualMachine; @@ -11,7 +11,6 @@ use crate::vm::VirtualMachine; pub struct PyMemoryView { obj_ref: PyObjectRef, } -impl ThreadSafe for PyMemoryView {} pub type PyMemoryViewRef = PyRef; diff --git a/vm/src/obj/objproperty.rs b/vm/src/obj/objproperty.rs index 5351968606..9a4d9fa05c 100644 --- a/vm/src/obj/objproperty.rs +++ b/vm/src/obj/objproperty.rs @@ -6,8 +6,7 @@ use std::sync::RwLock; use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyobject::{ - IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TypeProtocol, + IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::slots::SlotDescriptor; use crate::vm::VirtualMachine; @@ -52,7 +51,6 @@ pub struct PyProperty { deleter: Option, doc: RwLock>, } -impl ThreadSafe for PyProperty {} impl PyValue for PyProperty { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 68bc41dcd6..ef93ee230b 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -12,8 +12,7 @@ use super::objtype::PyClassRef; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyhash; use crate::pyobject::{ - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, - TypeProtocol, + PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -32,7 +31,6 @@ pub struct PyRange { pub stop: PyIntRef, pub step: PyIntRef, } -impl ThreadSafe for PyRange {} impl PyValue for PyRange { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -405,7 +403,6 @@ pub struct PyRangeIterator { position: AtomicCell, range: PyRangeRef, } -impl ThreadSafe for PyRangeIterator {} impl PyValue for PyRangeIterator { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index f467241177..0c566afdbc 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -9,8 +9,8 @@ use crate::dictdatatype; use crate::function::{Args, OptionalArg}; use crate::pyhash; use crate::pyobject::{ - PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryFromObject, TypeProtocol, + PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::vm::{ReprGuard, VirtualMachine}; @@ -26,7 +26,6 @@ pub struct PySet { inner: PySetInner, } pub type PySetRef = PyRef; -impl ThreadSafe for PySet {} /// frozenset() -> empty frozenset object /// frozenset(iterable) -> frozenset object @@ -38,7 +37,6 @@ pub struct PyFrozenSet { inner: PySetInner, } pub type PyFrozenSetRef = PyRef; -impl ThreadSafe for PyFrozenSet {} impl fmt::Debug for PySet { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index 198b740ca4..ff11cc17d1 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -2,8 +2,7 @@ use super::objint::PyInt; use super::objtype::PyClassRef; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryIntoRef, + IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, }; use crate::vm::VirtualMachine; use num_bigint::{BigInt, ToBigInt}; @@ -16,7 +15,6 @@ pub struct PySlice { pub stop: PyObjectRef, pub step: Option, } -impl ThreadSafe for PySlice {} impl PyValue for PySlice { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objstaticmethod.rs b/vm/src/obj/objstaticmethod.rs index 7b3f1241a3..f78842b2cd 100644 --- a/vm/src/obj/objstaticmethod.rs +++ b/vm/src/obj/objstaticmethod.rs @@ -1,6 +1,6 @@ use super::objtype::PyClassRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::slots::SlotDescriptor; use crate::vm::VirtualMachine; @@ -10,7 +10,6 @@ pub struct PyStaticMethod { pub callable: PyObjectRef, } pub type PyStaticMethodRef = PyRef; -impl ThreadSafe for PyStaticMethod {} impl PyValue for PyStaticMethod { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 0b1747affd..cc50eb898f 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -31,7 +31,7 @@ use crate::function::{OptionalArg, OptionalOption, PyFuncArgs}; use crate::pyhash; use crate::pyobject::{ Either, IdProtocol, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyIterable, - PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, TryIntoRef, TypeProtocol, + PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TryIntoRef, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -52,7 +52,6 @@ pub struct PyString { hash: AtomicCell>, len: AtomicCell>, } -impl ThreadSafe for PyString {} impl PyString { #[inline] @@ -103,7 +102,6 @@ pub struct PyStringIterator { pub string: PyStringRef, position: AtomicCell, } -impl ThreadSafe for PyStringIterator {} impl PyValue for PyStringIterator { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -137,7 +135,6 @@ pub struct PyStringReverseIterator { pub position: AtomicCell, pub string: PyStringRef, } -impl ThreadSafe for PyStringReverseIterator {} impl PyValue for PyStringReverseIterator { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -532,6 +529,36 @@ impl PyString { ) } + /// removeprefix($self, prefix, /) + /// + /// + /// Return a str with the given prefix string removed if present. + /// + /// If the string starts with the prefix string, return string[len(prefix):] + /// Otherwise, return a copy of the original string. + #[pymethod] + fn removeprefix(&self, pref: PyStringRef) -> String { + self.value + .as_str() + .py_removeprefix(&pref.value, pref.value.len(), |s, p| s.starts_with(p)) + .to_string() + } + + /// removesuffix(self, prefix, /) + /// + /// + /// Return a str with the given suffix string removed if present. + /// + /// If the string ends with the suffix string, return string[:len(suffix)] + /// Otherwise, return a copy of the original string. + #[pymethod] + fn removesuffix(&self, suff: PyStringRef) -> String { + self.value + .as_str() + .py_removesuffix(&suff.value, suff.value.len(), |s, p| s.ends_with(p)) + .to_string() + } + #[pymethod] fn removeprefix(&self, pref: PyStringRef) -> PyResult { if self.value.as_str().starts_with(&pref.value) { diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index e7dabdb57e..e9d0bd364c 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -10,8 +10,8 @@ use super::objstr::PyStringRef; use super::objtype::{self, PyClass, PyClassRef}; use crate::function::OptionalArg; use crate::pyobject::{ - IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, - TryFromObject, TypeProtocol, + IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::scope::NameProtocol; use crate::slots::SlotDescriptor; @@ -27,7 +27,6 @@ pub struct PySuper { typ: PyClassRef, obj: Option<(PyObjectRef, PyClassRef)>, } -impl ThreadSafe for PySuper {} impl PyValue for PySuper { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/objtraceback.rs b/vm/src/obj/objtraceback.rs index f51f6d3b02..ff49abbf19 100644 --- a/vm/src/obj/objtraceback.rs +++ b/vm/src/obj/objtraceback.rs @@ -1,6 +1,6 @@ use crate::frame::FrameRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyClassImpl, PyContext, PyRef, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyRef, PyValue}; use crate::vm::VirtualMachine; #[pyclass] @@ -11,7 +11,6 @@ pub struct PyTraceback { pub lasti: usize, pub lineno: usize, } -impl ThreadSafe for PyTraceback {} pub type PyTracebackRef = PyRef; diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index 85427a1898..1b30994735 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -9,7 +9,7 @@ use crate::pyhash; use crate::pyobject::{ IntoPyObject, PyArithmaticValue::{self, *}, - PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, + PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, }; use crate::sequence; use crate::vm::{ReprGuard, VirtualMachine}; @@ -23,8 +23,6 @@ pub struct PyTuple { elements: Vec, } -impl ThreadSafe for PyTuple {} - impl fmt::Debug for PyTuple { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // TODO: implement more informational, non-recursive Debug formatter @@ -249,8 +247,6 @@ pub struct PyTupleIterator { tuple: PyTupleRef, } -impl ThreadSafe for PyTupleIterator {} - impl PyValue for PyTupleIterator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.tupleiterator_type() diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index 07de6d8a3c..85b2b4a3ff 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -12,7 +12,7 @@ use super::objweakref::PyWeak; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ IdProtocol, PyAttributes, PyClassImpl, PyContext, PyIterable, PyObject, PyObjectRef, PyRef, - PyResult, PyValue, ThreadSafe, TypeProtocol, + PyResult, PyValue, TypeProtocol, }; use crate::slots::{PyClassSlots, PyTpFlags}; use crate::vm::VirtualMachine; @@ -32,8 +32,6 @@ pub struct PyClass { pub slots: RwLock, } -impl ThreadSafe for PyClass {} - impl fmt::Display for PyClass { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.name, f) diff --git a/vm/src/obj/objweakproxy.rs b/vm/src/obj/objweakproxy.rs index f147f3decc..d10f9fcc4a 100644 --- a/vm/src/obj/objweakproxy.rs +++ b/vm/src/obj/objweakproxy.rs @@ -1,7 +1,7 @@ use super::objtype::PyClassRef; use super::objweakref::PyWeak; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[pyclass] @@ -10,8 +10,6 @@ pub struct PyWeakProxy { weak: PyWeak, } -impl ThreadSafe for PyWeakProxy {} - impl PyValue for PyWeakProxy { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.weakproxy_type() diff --git a/vm/src/obj/objweakref.rs b/vm/src/obj/objweakref.rs index bf3d7c0810..64f43a4eec 100644 --- a/vm/src/obj/objweakref.rs +++ b/vm/src/obj/objweakref.rs @@ -1,26 +1,28 @@ use super::objtype::PyClassRef; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - PyClassImpl, PyContext, PyObject, PyObjectPayload, PyObjectRef, PyRef, PyResult, PyValue, - ThreadSafe, + IdProtocol, PyClassImpl, PyContext, PyObject, PyObjectPayload, PyObjectRef, PyRef, PyResult, + PyValue, TypeProtocol, }; use crate::slots::SlotCall; use crate::vm::VirtualMachine; +use crate::pyhash::PyHash; +use crossbeam_utils::atomic::AtomicCell; use std::sync::{Arc, Weak}; #[pyclass] #[derive(Debug)] pub struct PyWeak { referent: Weak>, + hash: AtomicCell>, } -impl ThreadSafe for PyWeak {} - impl PyWeak { pub fn downgrade(obj: &PyObjectRef) -> PyWeak { PyWeak { referent: Arc::downgrade(obj), + hash: AtomicCell::new(None), } } @@ -56,6 +58,48 @@ impl PyWeak { ) -> PyResult> { PyWeak::downgrade(&referent).into_ref_with_type(vm, cls) } + + #[pymethod(magic)] + fn hash(&self, vm: &VirtualMachine) -> PyResult { + match self.hash.load() { + Some(hash) => Ok(hash), + None => { + let obj = self + .upgrade() + .ok_or_else(|| vm.new_type_error("weak object has gone away".to_owned()))?; + let hash = vm._hash(&obj)?; + self.hash.store(Some(hash)); + Ok(hash) + } + } + } + + #[pymethod(magic)] + fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(other) = other.payload_if_subclass::(vm) { + self.upgrade() + .and_then(|s| other.upgrade().map(|o| (s, o))) + .map_or(Ok(false), |(a, b)| vm.bool_eq(a, b)) + .map(|b| vm.new_bool(b)) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + #[pymethod(magic)] + fn repr(zelf: PyRef) -> String { + let id = zelf.get_id(); + if let Some(o) = zelf.upgrade() { + format!( + "", + id, + o.class().name, + o.get_id(), + ) + } else { + format!("", id) + } + } } pub fn init(context: &PyContext) { diff --git a/vm/src/obj/objzip.rs b/vm/src/obj/objzip.rs index 9421b0d72e..751a9731f7 100644 --- a/vm/src/obj/objzip.rs +++ b/vm/src/obj/objzip.rs @@ -1,7 +1,7 @@ use super::objiter; use super::objtype::PyClassRef; use crate::function::Args; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; pub type PyZipRef = PyRef; @@ -11,7 +11,6 @@ pub type PyZipRef = PyRef; pub struct PyZip { iterators: Vec, } -impl ThreadSafe for PyZip {} impl PyValue for PyZip { fn class(vm: &VirtualMachine) -> PyClassRef { diff --git a/vm/src/obj/pystr.rs b/vm/src/obj/pystr.rs index 43a0260128..9f262dfe7f 100644 --- a/vm/src/obj/pystr.rs +++ b/vm/src/obj/pystr.rs @@ -264,4 +264,37 @@ pub trait PyCommonString { fn py_rjust(&self, width: usize, fillchar: E) -> Self::Container { self.py_pad(width - self.chars_len(), 0, fillchar) } + + fn py_removeprefix( + &self, + prefix: &Self::Container, + prefix_len: usize, + is_prefix: FC, + ) -> &Self + where + FC: Fn(&Self, &Self::Container) -> bool, + { + //if self.py_starts_with(prefix) { + if is_prefix(&self, &prefix) { + self.get_bytes(prefix_len..self.bytes_len()) + } else { + &self + } + } + + fn py_removesuffix( + &self, + suffix: &Self::Container, + suffix_len: usize, + is_suffix: FC, + ) -> &Self + where + FC: Fn(&Self, &Self::Container) -> bool, + { + if is_suffix(&self, &suffix) { + self.get_bytes(0..self.bytes_len() - suffix_len) + } else { + &self + } + } } diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 2ce4bd08f1..1408020888 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1193,7 +1193,7 @@ impl PyObject { } } -pub trait PyValue: fmt::Debug + Sized + 'static { +pub trait PyValue: fmt::Debug + Send + Sync + Sized + 'static { const HAVE_DICT: bool = false; fn class(vm: &VirtualMachine) -> PyClassRef; @@ -1223,14 +1223,7 @@ pub trait PyValue: fmt::Debug + Sized + 'static { } } -// Temporary trait to follow the progress of threading conversion -pub trait ThreadSafe: Send + Sync {} -// Temporary hack to help with converting object that contain PyObjectRef to ThreadSafe. -// Should be removed before threading is allowed. Do not try this at home!!! -unsafe impl Send for PyObject {} -unsafe impl Sync for PyObject {} - -pub trait PyObjectPayload: Any + fmt::Debug + 'static { +pub trait PyObjectPayload: Any + fmt::Debug + Send + Sync + 'static { fn as_any(&self) -> &dyn Any; } diff --git a/vm/src/stdlib/array.rs b/vm/src/stdlib/array.rs index b4a437477e..9071ee62fb 100644 --- a/vm/src/stdlib/array.rs +++ b/vm/src/stdlib/array.rs @@ -6,14 +6,15 @@ use crate::obj::objtype::PyClassRef; use crate::obj::{objbool, objiter}; use crate::pyobject::{ Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, - ThreadSafe, TryFromObject, + TryFromObject, }; use crate::VirtualMachine; -use std::cell::Cell; use std::fmt; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use crossbeam_utils::atomic::AtomicCell; + struct ArrayTypeSpecifierError { _priv: (), } @@ -227,8 +228,6 @@ pub struct PyArray { array: RwLock, } -impl ThreadSafe for PyArray {} - pub type PyArrayRef = PyRef; impl PyValue for PyArray { @@ -421,7 +420,7 @@ impl PyArray { #[pymethod(name = "__iter__")] fn iter(zelf: PyRef) -> PyArrayIter { PyArrayIter { - position: Cell::new(0), + position: AtomicCell::new(0), array: zelf, } } @@ -430,7 +429,7 @@ impl PyArray { #[pyclass] #[derive(Debug)] pub struct PyArrayIter { - position: Cell, + position: AtomicCell, array: PyArrayRef, } @@ -444,14 +443,9 @@ impl PyValue for PyArrayIter { impl PyArrayIter { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.array.borrow_value().len() { - let ret = self - .array - .borrow_value() - .getitem_by_idx(self.position.get(), vm) - .unwrap()?; - self.position.set(self.position.get() + 1); - Ok(ret) + let pos = self.position.fetch_add(1); + if let Some(item) = self.array.borrow_value().getitem_by_idx(pos, vm) { + Ok(item?) } else { Err(objiter::new_stop_iteration(vm)) } diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 4ec3633b5c..8cd32d9498 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -12,14 +12,16 @@ mod _collections { use crate::vm::ReprGuard; use crate::VirtualMachine; use itertools::Itertools; - use std::cell::{Cell, RefCell}; use std::collections::VecDeque; + use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + + use crossbeam_utils::atomic::AtomicCell; #[pyclass(name = "deque")] - #[derive(Debug, Clone)] + #[derive(Debug)] struct PyDeque { - deque: RefCell>, - maxlen: Cell>, + deque: RwLock>, + maxlen: AtomicCell>, } type PyDequeRef = PyRef; @@ -36,8 +38,12 @@ mod _collections { } impl PyDeque { - fn borrow_deque<'a>(&'a self) -> impl std::ops::Deref> + 'a { - self.deque.borrow() + fn borrow_deque(&self) -> RwLockReadGuard<'_, VecDeque> { + self.deque.read().unwrap() + } + + fn borrow_deque_mut(&self) -> RwLockWriteGuard<'_, VecDeque> { + self.deque.write().unwrap() } } @@ -51,8 +57,8 @@ mod _collections { vm: &VirtualMachine, ) -> PyResult> { let py_deque = PyDeque { - deque: RefCell::default(), - maxlen: maxlen.into(), + deque: RwLock::default(), + maxlen: AtomicCell::new(maxlen), }; if let OptionalArg::Present(iter) = iter { py_deque.extend(iter, vm)?; @@ -62,8 +68,8 @@ mod _collections { #[pymethod] fn append(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { + let mut deque = self.borrow_deque_mut(); + if self.maxlen.load() == Some(deque.len()) { deque.pop_front(); } deque.push_back(obj); @@ -71,8 +77,8 @@ mod _collections { #[pymethod] fn appendleft(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { + let mut deque = self.borrow_deque_mut(); + if self.maxlen.load() == Some(deque.len()) { deque.pop_back(); } deque.push_front(obj); @@ -80,18 +86,21 @@ mod _collections { #[pymethod] fn clear(&self) { - self.deque.borrow_mut().clear() + self.borrow_deque_mut().clear() } #[pymethod] fn copy(&self) -> Self { - self.clone() + PyDeque { + deque: RwLock::new(self.borrow_deque().clone()), + maxlen: AtomicCell::new(self.maxlen.load()), + } } #[pymethod] fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count = 0; - for elem in self.deque.borrow().iter() { + for elem in self.borrow_deque().iter() { if vm.identical_or_equal(elem, &obj)? { count += 1; } @@ -124,7 +133,7 @@ mod _collections { stop: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let deque = self.deque.borrow(); + let deque = self.borrow_deque(); let start = start.unwrap_or(0); let stop = stop.unwrap_or_else(|| deque.len()); for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() { @@ -141,9 +150,9 @@ mod _collections { #[pymethod] fn insert(&self, idx: i32, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); - if self.maxlen.get() == Some(deque.len()) { + if self.maxlen.load() == Some(deque.len()) { return Err(vm.new_index_error("deque already at its maximum size".to_owned())); } @@ -166,23 +175,21 @@ mod _collections { #[pymethod] fn pop(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() + self.borrow_deque_mut() .pop_back() .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) } #[pymethod] fn popleft(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() + self.borrow_deque_mut() .pop_front() .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) } #[pymethod] fn remove(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); let mut idx = None; for (i, elem) in deque.iter().enumerate() { if vm.identical_or_equal(elem, &obj)? { @@ -196,13 +203,13 @@ mod _collections { #[pymethod] fn reverse(&self) { - self.deque - .replace_with(|deque| deque.iter().cloned().rev().collect()); + let rev: VecDeque<_> = self.borrow_deque().iter().cloned().rev().collect(); + *self.borrow_deque_mut() = rev; } #[pymethod] fn rotate(&self, mid: OptionalArg) { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); let mid = mid.unwrap_or(1); if mid < 0 { deque.rotate_left(-mid as usize); @@ -213,26 +220,25 @@ mod _collections { #[pyproperty] fn maxlen(&self) -> Option { - self.maxlen.get() + self.maxlen.load() } #[pyproperty(setter)] fn set_maxlen(&self, maxlen: Option) { - self.maxlen.set(maxlen); + self.maxlen.store(maxlen); } #[pymethod(name = "__repr__")] fn repr(zelf: PyRef, vm: &VirtualMachine) -> PyResult { let repr = if let Some(_guard) = ReprGuard::enter(zelf.as_object()) { let elements = zelf - .deque - .borrow() + .borrow_deque() .iter() .map(|obj| vm.to_repr(obj)) .collect::, _>>()?; let maxlen = zelf .maxlen - .get() + .load() .map(|maxlen| format!(", maxlen={}", maxlen)) .unwrap_or_default(); format!("deque([{}]{})", elements.into_iter().format(", "), maxlen) @@ -336,29 +342,29 @@ mod _collections { #[pymethod(name = "__mul__")] fn mul(&self, n: isize) -> Self { - let deque: &VecDeque<_> = &self.deque.borrow(); + let deque: &VecDeque<_> = &self.borrow_deque(); let mul = sequence::seq_mul(deque, n); - let skipped = if let Some(maxlen) = self.maxlen.get() { + let skipped = if let Some(maxlen) = self.maxlen.load() { mul.len() - maxlen } else { 0 }; let deque = mul.skip(skipped).cloned().collect(); PyDeque { - deque: RefCell::new(deque), - maxlen: self.maxlen.clone(), + deque: RwLock::new(deque), + maxlen: AtomicCell::new(self.maxlen.load()), } } #[pymethod(name = "__len__")] fn len(&self) -> usize { - self.deque.borrow().len() + self.borrow_deque().len() } #[pymethod(name = "__iter__")] fn iter(zelf: PyRef) -> PyDequeIterator { PyDequeIterator { - position: Cell::new(0), + position: AtomicCell::new(0), deque: zelf, } } @@ -367,7 +373,7 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug)] struct PyDequeIterator { - position: Cell, + position: AtomicCell, deque: PyDequeRef, } @@ -381,9 +387,10 @@ mod _collections { impl PyDequeIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.deque.deque.borrow().len() { - let ret = self.deque.deque.borrow()[self.position.get()].clone(); - self.position.set(self.position.get() + 1); + let pos = self.position.fetch_add(1); + let deque = self.deque.borrow_deque(); + if pos < deque.len() { + let ret = deque[pos].clone(); Ok(ret) } else { Err(objiter::new_stop_iteration(vm)) diff --git a/vm/src/stdlib/csv.rs b/vm/src/stdlib/csv.rs index fc42643a2a..6880bf15ac 100644 --- a/vm/src/stdlib/csv.rs +++ b/vm/src/stdlib/csv.rs @@ -1,5 +1,5 @@ -use std::cell::RefCell; use std::fmt::{self, Debug, Formatter}; +use std::sync::RwLock; use csv as rust_csv; use itertools::join; @@ -126,7 +126,7 @@ impl ReadState { #[pyclass(name = "Reader")] struct Reader { - state: RefCell, + state: RwLock, } impl Debug for Reader { @@ -143,7 +143,7 @@ impl PyValue for Reader { impl Reader { fn new(iter: PyIterable, config: ReaderOption) -> Self { - let state = RefCell::new(ReadState::new(iter, config)); + let state = RwLock::new(ReadState::new(iter, config)); Reader { state } } } @@ -152,13 +152,13 @@ impl Reader { impl Reader { #[pymethod(name = "__iter__")] fn iter(this: PyRef, vm: &VirtualMachine) -> PyResult { - this.state.borrow_mut().cast_to_reader(vm)?; + this.state.write().unwrap().cast_to_reader(vm)?; this.into_pyobject(vm) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let mut state = self.state.borrow_mut(); + let mut state = self.state.write().unwrap(); state.cast_to_reader(vm)?; if let ReadState::CsvIter(ref mut reader) = &mut *state { diff --git a/vm/src/stdlib/errno.rs b/vm/src/stdlib/errno.rs index 6113b48c22..28934a1bd3 100644 --- a/vm/src/stdlib/errno.rs +++ b/vm/src/stdlib/errno.rs @@ -88,7 +88,7 @@ const ERROR_CODES: &[(&str, i32)] = &[ e!(ENODEV), e!(EHOSTUNREACH), e!(cfg(not(windows)), ENOMSG), - e!(cfg(not(windows)), ENODATA), + e!(cfg(not(any(target_os = "openbsd", windows))), ENODATA), e!(cfg(not(windows)), ENOTBLK), e!(ENOSYS), e!(EPIPE), @@ -115,7 +115,7 @@ const ERROR_CODES: &[(&str, i32)] = &[ e!(EISCONN), e!(ESHUTDOWN), e!(EBADF), - e!(cfg(not(windows)), EMULTIHOP), + e!(cfg(not(any(target_os = "openbsd", windows))), EMULTIHOP), e!(EIO), e!(EPROTOTYPE), e!(ENOSPC), @@ -136,13 +136,13 @@ const ERROR_CODES: &[(&str, i32)] = &[ e!(cfg(not(windows)), EBADMSG), e!(ENFILE), e!(ESPIPE), - e!(cfg(not(windows)), ENOLINK), + e!(cfg(not(any(target_os = "openbsd", windows))), ENOLINK), e!(ENETRESET), e!(ETIMEDOUT), e!(ENOENT), e!(EEXIST), e!(EDQUOT), - e!(cfg(not(windows)), ENOSTR), + e!(cfg(not(any(target_os = "openbsd", windows))), ENOSTR), e!(EFAULT), e!(EFBIG), e!(ENOTCONN), @@ -151,7 +151,7 @@ const ERROR_CODES: &[(&str, i32)] = &[ e!(ECONNABORTED), e!(ENETUNREACH), e!(ESTALE), - e!(cfg(not(windows)), ENOSR), + e!(cfg(not(any(target_os = "openbsd", windows))), ENOSR), e!(ENOMEM), e!(ENOTSOCK), e!(EMLINK), @@ -162,7 +162,7 @@ const ERROR_CODES: &[(&str, i32)] = &[ e!(ENAMETOOLONG), e!(ENOTTY), e!(ESOCKTNOSUPPORT), - e!(cfg(not(windows)), ETIME), + e!(cfg(not(any(target_os = "openbsd", windows))), ETIME), e!(ETOOMANYREFS), e!(EMFILE), e!(cfg(not(windows)), ETXTBSY), diff --git a/vm/src/stdlib/hashlib.rs b/vm/src/stdlib/hashlib.rs index 88319a535c..f37426117d 100644 --- a/vm/src/stdlib/hashlib.rs +++ b/vm/src/stdlib/hashlib.rs @@ -2,7 +2,7 @@ use crate::function::{OptionalArg, PyFuncArgs}; use crate::obj::objbytes::{PyBytes, PyBytesRef}; use crate::obj::objstr::PyStringRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, ThreadSafe}; +use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use std::fmt; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; @@ -20,8 +20,6 @@ struct PyHasher { buffer: RwLock, } -impl ThreadSafe for PyHasher {} - impl fmt::Debug for PyHasher { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "hasher {}", self.name) diff --git a/vm/src/stdlib/imp.rs b/vm/src/stdlib/imp.rs index 43e8fa9321..92ec0bc2b2 100644 --- a/vm/src/stdlib/imp.rs +++ b/vm/src/stdlib/imp.rs @@ -26,11 +26,11 @@ fn imp_lock_held(_vm: &VirtualMachine) -> PyResult<()> { } fn imp_is_builtin(name: PyStringRef, vm: &VirtualMachine) -> bool { - vm.stdlib_inits.borrow().contains_key(name.as_str()) + vm.state.stdlib_inits.contains_key(name.as_str()) } fn imp_is_frozen(name: PyStringRef, vm: &VirtualMachine) -> bool { - vm.frozen.borrow().contains_key(name.as_str()) + vm.state.frozen.contains_key(name.as_str()) } fn imp_create_builtin(spec: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -40,7 +40,7 @@ fn imp_create_builtin(spec: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Ok(module) = sys_modules.get_item(name, vm) { Ok(module) - } else if let Some(make_module_func) = vm.stdlib_inits.borrow().get(name) { + } else if let Some(make_module_func) = vm.state.stdlib_inits.get(name) { Ok(make_module_func(vm)) } else { Ok(vm.get_none()) @@ -53,8 +53,8 @@ fn imp_exec_builtin(_mod: PyModuleRef) -> i32 { } fn imp_get_frozen_object(name: PyStringRef, vm: &VirtualMachine) -> PyResult { - vm.frozen - .borrow() + vm.state + .frozen .get(name.as_str()) .map(|frozen| { let mut frozen = frozen.code.clone(); @@ -71,8 +71,8 @@ fn imp_init_frozen(name: PyStringRef, vm: &VirtualMachine) -> PyResult { } fn imp_is_frozen_package(name: PyStringRef, vm: &VirtualMachine) -> PyResult { - vm.frozen - .borrow() + vm.state + .frozen .get(name.as_str()) .map(|frozen| frozen.package) .ok_or_else(|| { diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 3b2760b92d..6bf4b4cdd4 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -1,10 +1,11 @@ /* * I/O core tools. */ -use std::cell::{RefCell, RefMut}; use std::fs; use std::io::{self, prelude::*, Cursor, SeekFrom}; +use std::sync::{RwLock, RwLockWriteGuard}; +use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; use crate::exceptions::PyBaseExceptionRef; @@ -120,7 +121,8 @@ impl BufferedIO { #[derive(Debug)] struct PyStringIO { - buffer: RefCell>, + buffer: RwLock, + closed: AtomicCell, } type PyStringIORef = PyRef; @@ -132,10 +134,9 @@ impl PyValue for PyStringIO { } impl PyStringIORef { - fn buffer(&self, vm: &VirtualMachine) -> PyResult> { - let buffer = self.buffer.borrow_mut(); - if buffer.is_some() { - Ok(RefMut::map(buffer, |opt| opt.as_mut().unwrap())) + fn buffer(&self, vm: &VirtualMachine) -> PyResult> { + if !self.closed.load() { + Ok(self.buffer.write().unwrap()) } else { Err(vm.new_value_error("I/O operation on closed file.".to_owned())) } @@ -209,11 +210,11 @@ impl PyStringIORef { } fn closed(self) -> bool { - self.buffer.borrow().is_none() + self.closed.load() } fn close(self) { - self.buffer.replace(None); + self.closed.store(true); } } @@ -235,14 +236,16 @@ fn string_io_new( let input = flatten.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec()); PyStringIO { - buffer: RefCell::new(Some(BufferedIO::new(Cursor::new(input)))), + buffer: RwLock::new(BufferedIO::new(Cursor::new(input))), + closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } #[derive(Debug)] struct PyBytesIO { - buffer: RefCell>, + buffer: RwLock, + closed: AtomicCell, } type PyBytesIORef = PyRef; @@ -254,10 +257,9 @@ impl PyValue for PyBytesIO { } impl PyBytesIORef { - fn buffer(&self, vm: &VirtualMachine) -> PyResult> { - let buffer = self.buffer.borrow_mut(); - if buffer.is_some() { - Ok(RefMut::map(buffer, |opt| opt.as_mut().unwrap())) + fn buffer(&self, vm: &VirtualMachine) -> PyResult> { + if !self.closed.load() { + Ok(self.buffer.write().unwrap()) } else { Err(vm.new_value_error("I/O operation on closed file.".to_owned())) } @@ -320,11 +322,11 @@ impl PyBytesIORef { } fn closed(self) -> bool { - self.buffer.borrow().is_none() + self.closed.load() } fn close(self) { - self.buffer.replace(None); + self.closed.store(true) } } @@ -339,7 +341,8 @@ fn bytes_io_new( }; PyBytesIO { - buffer: RefCell::new(Some(BufferedIO::new(Cursor::new(raw_bytes)))), + buffer: RwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), + closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -519,6 +522,16 @@ fn buffered_io_base_fileno(instance: PyObjectRef, vm: &VirtualMachine) -> PyResu vm.call_method(&raw, "fileno", vec![]) } +fn buffered_io_base_mode(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let raw = vm.get_attribute(instance, "raw")?; + vm.get_attribute(raw, "mode") +} + +fn buffered_io_base_name(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let raw = vm.get_attribute(instance, "raw")?; + vm.get_attribute(raw, "name") +} + fn buffered_reader_read( instance: PyObjectRef, size: OptionalOption, @@ -589,6 +602,10 @@ mod fileio { opener: Option, } fn file_io_init(file_io: PyObjectRef, args: FileIOArgs, vm: &VirtualMachine) -> PyResult { + let mode = args + .mode + .map(|mode| mode.as_str().to_owned()) + .unwrap_or_else(|| "r".to_owned()); let (name, file_no) = match args.name { Either::A(name) => { if !args.closefd { @@ -596,10 +613,7 @@ mod fileio { vm.new_value_error("Cannot use closefd=False with file name".to_owned()) ); } - let mode = match args.mode { - Some(mode) => compute_c_flag(mode.as_str()), - None => libc::O_RDONLY as _, - }; + let mode = compute_c_flag(&mode); let fd = if let Some(opener) = args.opener { let fd = vm.invoke(&opener, vec![name.clone().into_object(), vm.new_int(mode)])?; @@ -626,6 +640,7 @@ mod fileio { }; vm.set_attr(&file_io, "name", name)?; + vm.set_attr(&file_io, "mode", vm.new_str(mode))?; vm.set_attr(&file_io, "__fileno", vm.new_int(file_no))?; vm.set_attr(&file_io, "closefd", vm.new_bool(args.closefd))?; vm.set_attr(&file_io, "__closed", vm.new_bool(false))?; @@ -826,6 +841,16 @@ fn text_io_wrapper_tell(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult vm.invoke(&vm.get_attribute(raw, "tell")?, vec![]) } +fn text_io_wrapper_mode(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let raw = vm.get_attribute(instance, "buffer")?; + vm.get_attribute(raw, "mode") +} + +fn text_io_wrapper_name(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let raw = vm.get_attribute(instance, "buffer")?; + vm.get_attribute(raw, "name") +} + fn text_io_wrapper_read( instance: PyObjectRef, size: OptionalOption, @@ -1055,6 +1080,8 @@ pub fn io_open( )), )?; + vm.set_attr(&file_io_obj, "mode", vm.new_str(mode_string.to_owned()))?; + // Create Buffered class to consume FileIO. The type of buffered class depends on // the operation in the mode. // There are 3 possible classes here, each inheriting from the RawBaseIO @@ -1137,6 +1164,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "tell" => ctx.new_method(buffered_reader_tell), "close" => ctx.new_method(buffered_reader_close), "fileno" => ctx.new_method(buffered_io_base_fileno), + "name" => ctx.new_readonly_getset("name", buffered_io_base_name), + "mode" => ctx.new_readonly_getset("mode", buffered_io_base_mode), }); let buffered_writer = py_class!(ctx, "BufferedWriter", buffered_io_base.clone(), { @@ -1146,7 +1175,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "__init__" => ctx.new_method(buffered_io_base_init), "write" => ctx.new_method(buffered_writer_write), "seekable" => ctx.new_method(buffered_writer_seekable), - "fileno" => ctx.new_method(buffered_io_base_fileno), + "name" => ctx.new_readonly_getset("name", buffered_io_base_name), + "mode" => ctx.new_readonly_getset("mode", buffered_io_base_mode), }); //TextIOBase Subclass @@ -1158,6 +1188,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "read" => ctx.new_method(text_io_wrapper_read), "write" => ctx.new_method(text_io_wrapper_write), "readline" => ctx.new_method(text_io_wrapper_readline), + "name" => ctx.new_readonly_getset("name", text_io_wrapper_name), + "mode" => ctx.new_readonly_getset("mode", text_io_wrapper_mode), }); //StringIO: in-memory text diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ea7767fb20..a9dafaaf70 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -2,11 +2,11 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { + use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::{One, Signed, ToPrimitive, Zero}; - use std::cell::{Cell, RefCell}; use std::iter; - use std::rc::Rc; + use std::sync::{Arc, RwLock, RwLockWriteGuard}; use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; use crate::obj::objbool; @@ -23,7 +23,8 @@ mod decl { #[derive(Debug)] struct PyItertoolsChain { iterables: Vec, - cur: RefCell<(usize, Option)>, + cur_idx: AtomicCell, + cached_iter: RwLock>, } impl PyValue for PyItertoolsChain { @@ -38,27 +39,38 @@ mod decl { fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { PyItertoolsChain { iterables: args.args, - cur: RefCell::new((0, None)), + cur_idx: AtomicCell::new(0), + cached_iter: RwLock::new(None), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let (ref mut cur_idx, ref mut cur_iter) = *self.cur.borrow_mut(); - while *cur_idx < self.iterables.len() { - if cur_iter.is_none() { - *cur_iter = Some(get_iter(vm, &self.iterables[*cur_idx])?); + loop { + let pos = self.cur_idx.load(); + if pos >= self.iterables.len() { + break; } + let cur_iter = if self.cached_iter.read().unwrap().is_none() { + // We need to call "get_iter" outside of the lock. + let iter = get_iter(vm, &self.iterables[pos])?; + *self.cached_iter.write().unwrap() = Some(iter.clone()); + iter + } else if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { + cached_iter + } else { + // Someone changed cached iter to None since we checked. + continue; + }; - // can't be directly inside the 'match' clause, otherwise the borrows collide. - let obj = call_next(vm, cur_iter.as_ref().unwrap()); - match obj { + // We need to call "call_next" outside of the lock. + match call_next(vm, &cur_iter) { Ok(ok) => return Ok(ok), Err(err) => { if objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { - *cur_idx += 1; - *cur_iter = None; + self.cur_idx.fetch_add(1); + *self.cached_iter.write().unwrap() = None; } else { return Err(err); } @@ -85,7 +97,8 @@ mod decl { PyItertoolsChain { iterables, - cur: RefCell::new((0, None)), + cur_idx: AtomicCell::new(0), + cached_iter: RwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -145,7 +158,7 @@ mod decl { #[pyclass(name = "count")] #[derive(Debug)] struct PyItertoolsCount { - cur: RefCell, + cur: RwLock, step: BigInt, } @@ -174,7 +187,7 @@ mod decl { }; PyItertoolsCount { - cur: RefCell::new(start), + cur: RwLock::new(start), step, } .into_ref_with_type(vm, cls) @@ -182,8 +195,9 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self) -> PyResult { - let result = self.cur.borrow().clone(); - *self.cur.borrow_mut() += &self.step; + let mut cur = self.cur.write().unwrap(); + let result = cur.clone(); + *cur += &self.step; Ok(PyInt::new(result)) } @@ -196,10 +210,9 @@ mod decl { #[pyclass(name = "cycle")] #[derive(Debug)] struct PyItertoolsCycle { - iter: RefCell, - saved: RefCell>, - index: Cell, - first_pass: Cell, + iter: PyObjectRef, + saved: RwLock>, + index: AtomicCell, } impl PyValue for PyItertoolsCycle { @@ -219,36 +232,31 @@ mod decl { let iter = get_iter(vm, &iterable)?; PyItertoolsCycle { - iter: RefCell::new(iter.clone()), - saved: RefCell::new(Vec::new()), - index: Cell::new(0), - first_pass: Cell::new(false), + iter: iter.clone(), + saved: RwLock::new(Vec::new()), + index: AtomicCell::new(0), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? { - if self.first_pass.get() { - return Ok(item); - } - - self.saved.borrow_mut().push(item.clone()); + let item = if let Some(item) = get_next_object(vm, &self.iter)? { + self.saved.write().unwrap().push(item.clone()); item } else { - if self.saved.borrow().len() == 0 { + let saved = self.saved.read().unwrap(); + if saved.len() == 0 { return Err(new_stop_iteration(vm)); } - let last_index = self.index.get(); - self.index.set(self.index.get() + 1); + let last_index = self.index.fetch_add(1); - if self.index.get() >= self.saved.borrow().len() { - self.index.set(0); + if last_index >= saved.len() - 1 { + self.index.store(0); } - self.saved.borrow()[last_index].clone() + saved[last_index].clone() }; Ok(item) @@ -264,7 +272,7 @@ mod decl { #[derive(Debug)] struct PyItertoolsRepeat { object: PyObjectRef, - times: Option>, + times: Option>, } impl PyValue for PyItertoolsRepeat { @@ -283,7 +291,7 @@ mod decl { vm: &VirtualMachine, ) -> PyResult> { let times = match times.into_option() { - Some(int) => Some(RefCell::new(int.as_bigint().clone())), + Some(int) => Some(RwLock::new(int.as_bigint().clone())), None => None, }; @@ -297,10 +305,11 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { if let Some(ref times) = self.times { - if *times.borrow() <= BigInt::zero() { + let mut times = times.write().unwrap(); + if *times <= BigInt::zero() { return Err(new_stop_iteration(vm)); } - *times.borrow_mut() -= 1; + *times -= 1; } Ok(self.object.clone()) @@ -314,7 +323,7 @@ mod decl { #[pymethod(name = "__length_hint__")] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { match self.times { - Some(ref times) => vm.new_int(times.borrow().clone()), + Some(ref times) => vm.new_int(times.read().unwrap().clone()), None => vm.new_int(0), } } @@ -366,7 +375,7 @@ mod decl { struct PyItertoolsTakewhile { predicate: PyObjectRef, iterable: PyObjectRef, - stop_flag: RefCell, + stop_flag: AtomicCell, } impl PyValue for PyItertoolsTakewhile { @@ -389,14 +398,14 @@ mod decl { PyItertoolsTakewhile { predicate, iterable: iter, - stop_flag: RefCell::new(false), + stop_flag: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if *self.stop_flag.borrow() { + if self.stop_flag.load() { return Err(new_stop_iteration(vm)); } @@ -409,7 +418,7 @@ mod decl { if verdict { Ok(obj) } else { - *self.stop_flag.borrow_mut() = true; + self.stop_flag.store(true); Err(new_stop_iteration(vm)) } } @@ -425,7 +434,7 @@ mod decl { struct PyItertoolsDropwhile { predicate: PyCallable, iterable: PyObjectRef, - start_flag: Cell, + start_flag: AtomicCell, } impl PyValue for PyItertoolsDropwhile { @@ -448,7 +457,7 @@ mod decl { PyItertoolsDropwhile { predicate, iterable: iter, - start_flag: Cell::new(false), + start_flag: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -458,13 +467,13 @@ mod decl { let predicate = &self.predicate; let iterable = &self.iterable; - if !self.start_flag.get() { + if !self.start_flag.load() { loop { let obj = call_next(vm, iterable)?; let pred = predicate.clone(); let pred_value = vm.invoke(&pred.into_object(), vec![obj.clone()])?; if !objbool::boolval(vm, pred_value)? { - self.start_flag.set(true); + self.start_flag.store(true); return Ok(obj); } } @@ -482,8 +491,8 @@ mod decl { #[derive(Debug)] struct PyItertoolsIslice { iterable: PyObjectRef, - cur: RefCell, - next: RefCell, + cur: AtomicCell, + next: AtomicCell, stop: Option, step: usize, } @@ -567,8 +576,8 @@ mod decl { PyItertoolsIslice { iterable: iter, - cur: RefCell::new(0), - next: RefCell::new(start), + cur: AtomicCell::new(0), + next: AtomicCell::new(start), stop, step, } @@ -577,23 +586,23 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - while *self.cur.borrow() < *self.next.borrow() { + while self.cur.load() < self.next.load() { call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + self.cur.fetch_add(1); } if let Some(stop) = self.stop { - if *self.cur.borrow() >= stop { + if self.cur.load() >= stop { return Err(new_stop_iteration(vm)); } } let obj = call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + self.cur.fetch_add(1); // TODO is this overflow check required? attempts to copy CPython. - let (next, ovf) = (*self.next.borrow()).overflowing_add(self.step); - *self.next.borrow_mut() = if ovf { self.stop.unwrap() } else { next }; + let (next, ovf) = self.next.load().overflowing_add(self.step); + self.next.store(if ovf { self.stop.unwrap() } else { next }); Ok(obj) } @@ -665,7 +674,7 @@ mod decl { struct PyItertoolsAccumulate { iterable: PyObjectRef, binop: PyObjectRef, - acc_value: RefCell>, + acc_value: RwLock>, } impl PyValue for PyItertoolsAccumulate { @@ -688,7 +697,7 @@ mod decl { PyItertoolsAccumulate { iterable: iter, binop: binop.unwrap_or_else(|| vm.get_none()), - acc_value: RefCell::from(Option::None), + acc_value: RwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -698,7 +707,9 @@ mod decl { let iterable = &self.iterable; let obj = call_next(vm, iterable)?; - let next_acc_value = match &*self.acc_value.borrow() { + let acc_value = self.acc_value.read().unwrap().clone(); + + let next_acc_value = match acc_value { None => obj.clone(), Some(value) => { if self.binop.is(&vm.get_none()) { @@ -708,7 +719,7 @@ mod decl { } } }; - self.acc_value.replace(Option::from(next_acc_value.clone())); + *self.acc_value.write().unwrap() = Some(next_acc_value.clone()); Ok(next_acc_value) } @@ -722,31 +733,31 @@ mod decl { #[derive(Debug)] struct PyItertoolsTeeData { iterable: PyObjectRef, - values: RefCell>, + values: RwLock>, } impl PyItertoolsTeeData { - fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - Ok(Rc::new(PyItertoolsTeeData { + fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Ok(Arc::new(PyItertoolsTeeData { iterable: get_iter(vm, &iterable)?, - values: RefCell::new(vec![]), + values: RwLock::new(vec![]), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { - if self.values.borrow().len() == index { + if self.values.read().unwrap().len() == index { let result = call_next(vm, &self.iterable)?; - self.values.borrow_mut().push(result); + self.values.write().unwrap().push(result); } - Ok(self.values.borrow()[index].clone()) + Ok(self.values.read().unwrap()[index].clone()) } } #[pyclass(name = "tee")] #[derive(Debug)] struct PyItertoolsTee { - tee_data: Rc, - index: Cell, + tee_data: Arc, + index: AtomicCell, } impl PyValue for PyItertoolsTee { @@ -764,7 +775,7 @@ mod decl { } Ok(PyItertoolsTee { tee_data: PyItertoolsTeeData::new(it, vm)?, - index: Cell::from(0), + index: AtomicCell::new(0), } .into_ref_with_type(vm, PyItertoolsTee::class(vm))? .into_object()) @@ -800,8 +811,8 @@ mod decl { #[pymethod(name = "__copy__")] fn copy(&self, vm: &VirtualMachine) -> PyResult { Ok(PyItertoolsTee { - tee_data: Rc::clone(&self.tee_data), - index: self.index.clone(), + tee_data: Arc::clone(&self.tee_data), + index: AtomicCell::new(self.index.load()), } .into_ref_with_type(vm, Self::class(vm))? .into_object()) @@ -809,8 +820,8 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let value = self.tee_data.get_item(vm, self.index.get())?; - self.index.set(self.index.get() + 1); + let value = self.tee_data.get_item(vm, self.index.load())?; + self.index.fetch_add(1); Ok(value) } @@ -824,9 +835,9 @@ mod decl { #[derive(Debug)] struct PyItertoolsProduct { pools: Vec>, - idxs: RefCell>, - cur: Cell, - stop: Cell, + idxs: RwLock>, + cur: AtomicCell, + stop: AtomicCell, } impl PyValue for PyItertoolsProduct { @@ -871,9 +882,9 @@ mod decl { PyItertoolsProduct { pools, - idxs: RefCell::new(vec![0; l]), - cur: Cell::new(l - 1), - stop: Cell::new(false), + idxs: RwLock::new(vec![0; l]), + cur: AtomicCell::new(l - 1), + stop: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -881,7 +892,7 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.stop.get() { + if self.stop.load() { return Err(new_stop_iteration(vm)); } @@ -893,41 +904,36 @@ mod decl { } } + let idxs = self.idxs.write().unwrap(); + let res = PyTuple::from( pools .iter() - .zip(self.idxs.borrow().iter()) + .zip(idxs.iter()) .map(|(pool, idx)| pool[*idx].clone()) .collect::>(), ); - self.update_idxs(); - - if self.is_end() { - self.stop.set(true); - } + self.update_idxs(idxs); Ok(res.into_ref(vm).into_object()) } - fn is_end(&self) -> bool { - let cur = self.cur.get(); - self.idxs.borrow()[cur] == self.pools[cur].len() - 1 && cur == 0 - } + fn update_idxs(&self, mut idxs: RwLockWriteGuard<'_, Vec>) { + let cur = self.cur.load(); + let lst_idx = &self.pools[cur].len() - 1; - fn update_idxs(&self) { - let lst_idx = &self.pools[self.cur.get()].len() - 1; - - if self.idxs.borrow()[self.cur.get()] == lst_idx { - if self.is_end() { + if idxs[cur] == lst_idx { + if cur == 0 { + self.stop.store(true); return; } - self.idxs.borrow_mut()[self.cur.get()] = 0; - self.cur.set(self.cur.get() - 1); - self.update_idxs(); + idxs[cur] = 0; + self.cur.fetch_sub(1); + self.update_idxs(idxs); } else { - self.idxs.borrow_mut()[self.cur.get()] += 1; - self.cur.set(self.idxs.borrow().len() - 1); + idxs[cur] += 1; + self.cur.store(idxs.len() - 1); } } @@ -941,9 +947,9 @@ mod decl { #[derive(Debug)] struct PyItertoolsCombinations { pool: Vec, - indices: RefCell>, - r: Cell, - exhausted: Cell, + indices: RwLock>, + r: AtomicCell, + exhausted: AtomicCell, } impl PyValue for PyItertoolsCombinations { @@ -974,9 +980,9 @@ mod decl { PyItertoolsCombinations { pool, - indices: RefCell::new((0..r).collect()), - r: Cell::new(r), - exhausted: Cell::new(r > n), + indices: RwLock::new((0..r).collect()), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(r > n), } .into_ref_with_type(vm, cls) } @@ -989,27 +995,28 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if r == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } let res = PyTuple::from( self.indices - .borrow() + .read() + .unwrap() .iter() .map(|&i| self.pool[i].clone()) .collect::>(), ); - let mut indices = self.indices.borrow_mut(); + let mut indices = self.indices.write().unwrap(); // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). let mut idx = r as isize - 1; @@ -1020,7 +1027,7 @@ mod decl { // If no suitable index is found, then the indices are all at // their maximum value and we're done. if idx < 0 { - self.exhausted.set(true); + self.exhausted.store(true); } else { // Increment the current index which we know is not at its // maximum. Then move back to the right setting each index @@ -1040,9 +1047,9 @@ mod decl { #[derive(Debug)] struct PyItertoolsCombinationsWithReplacement { pool: Vec, - indices: RefCell>, - r: Cell, - exhausted: Cell, + indices: RwLock>, + r: AtomicCell, + exhausted: AtomicCell, } impl PyValue for PyItertoolsCombinationsWithReplacement { @@ -1073,9 +1080,9 @@ mod decl { PyItertoolsCombinationsWithReplacement { pool, - indices: RefCell::new(vec![0; r]), - r: Cell::new(r), - exhausted: Cell::new(n == 0 && r > 0), + indices: RwLock::new(vec![0; r]), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(n == 0 && r > 0), } .into_ref_with_type(vm, cls) } @@ -1088,19 +1095,19 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if r == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } - let mut indices = self.indices.borrow_mut(); + let mut indices = self.indices.write().unwrap(); let res = vm .ctx @@ -1115,7 +1122,7 @@ mod decl { // If no suitable index is found, then the indices are all at // their maximum value and we're done. if idx < 0 { - self.exhausted.set(true); + self.exhausted.store(true); } else { let index = indices[idx as usize] + 1; @@ -1133,12 +1140,12 @@ mod decl { #[pyclass(name = "permutations")] #[derive(Debug)] struct PyItertoolsPermutations { - pool: Vec, // Collected input iterable - indices: RefCell>, // One index per element in pool - cycles: RefCell>, // One rollover counter per element in the result - result: RefCell>>, // Indexes of the most recently returned result - r: Cell, // Size of result tuple - exhausted: Cell, // Set when the iterator is exhausted + pool: Vec, // Collected input iterable + indices: RwLock>, // One index per element in pool + cycles: RwLock>, // One rollover counter per element in the result + result: RwLock>>, // Indexes of the most recently returned result + r: AtomicCell, // Size of result tuple + exhausted: AtomicCell, // Set when the iterator is exhausted } impl PyValue for PyItertoolsPermutations { @@ -1179,11 +1186,11 @@ mod decl { PyItertoolsPermutations { pool, - indices: RefCell::new((0..n).collect()), - cycles: RefCell::new((0..r).map(|i| n - i).collect()), - result: RefCell::new(None), - r: Cell::new(r), - exhausted: Cell::new(r > n), + indices: RwLock::new((0..n).collect()), + cycles: RwLock::new((0..r).map(|i| n - i).collect()), + result: RwLock::new(None), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(r > n), } .into_ref_with_type(vm, cls) } @@ -1196,23 +1203,23 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if n == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } - let result = &mut *self.result.borrow_mut(); + let mut result = self.result.write().unwrap(); - if let Some(ref mut result) = result { - let mut indices = self.indices.borrow_mut(); - let mut cycles = self.cycles.borrow_mut(); + if let Some(ref mut result) = *result { + let mut indices = self.indices.write().unwrap(); + let mut cycles = self.cycles.write().unwrap(); let mut sentinel = false; // Decrement rightmost cycle, moving leftward upon zero rollover @@ -1241,7 +1248,7 @@ mod decl { } } if !sentinel { - self.exhausted.set(true); + self.exhausted.store(true); return Err(new_stop_iteration(vm)); } } else { @@ -1265,7 +1272,6 @@ mod decl { struct PyItertoolsZipLongest { iterators: Vec, fillvalue: PyObjectRef, - numactive: Cell, } impl PyValue for PyItertoolsZipLongest { @@ -1299,12 +1305,9 @@ mod decl { .map(|iterable| get_iter(vm, &iterable)) .collect::, _>>()?; - let numactive = Cell::new(iterators.len()); - PyItertoolsZipLongest { iterators, fillvalue, - numactive, } .into_ref_with_type(vm, cls) } @@ -1315,7 +1318,7 @@ mod decl { Err(new_stop_iteration(vm)) } else { let mut result: Vec = Vec::new(); - let mut numactive = self.numactive.get(); + let mut numactive = self.iterators.len(); for idx in 0..self.iterators.len() { let next_obj = match call_next(vm, &self.iterators[idx]) { diff --git a/vm/src/stdlib/math.rs b/vm/src/stdlib/math.rs index 1ddbdf2893..67016efe17 100644 --- a/vm/src/stdlib/math.rs +++ b/vm/src/stdlib/math.rs @@ -9,7 +9,9 @@ use statrs::function::gamma::{gamma, ln_gamma}; use num_bigint::BigInt; use num_traits::{One, Zero}; -use crate::function::{OptionalArg, PyFuncArgs}; + +use crate::function::{Args, OptionalArg}; + use crate::obj::objfloat::{self, IntoPyFloat, PyFloatRef}; use crate::obj::objint::{self, PyInt, PyIntRef}; use crate::obj::objtype; @@ -272,55 +274,33 @@ fn math_ldexp( Ok(value * (2_f64).powf(objint::try_float(i.as_bigint(), vm)?)) } -fn math_perf_arb_len_int_op( - args: PyFuncArgs, - vm: &VirtualMachine, - op: F, - default: BigInt, -) -> PyResult +fn math_perf_arb_len_int_op(args: Args, op: F, default: BigInt) -> BigInt where F: Fn(&BigInt, &PyInt) -> BigInt, { - if !args.kwargs.is_empty() { - Err(vm.new_type_error("Takes no keyword arguments".to_owned())) - } else if args.args.is_empty() { - Ok(default) - } else if args.args.len() == 1 { - let a: PyObjectRef = args.args[0].clone(); - if let Some(aa) = a.payload_if_subclass::(vm) { - let res = op(aa.as_bigint(), aa); - Ok(res) - } else { - Err(vm.new_type_error("Only integer arguments are supported".to_owned())) - } - } else { - let a = args.args[0].clone(); - if let Some(aa) = a.payload_if_subclass::(vm) { - let mut res = aa.as_bigint().clone(); - for b in args.args[1..].iter() { - if let Some(bb) = b.payload_if_subclass::(vm) { - res = op(&res, bb); - } else { - return Err( - vm.new_type_error("Only integer arguments are supported".to_owned()) - ); - } - } - Ok(res) - } else { - Err(vm.new_type_error("Only integer arguments are supported".to_owned())) - } + let argvec = args.into_vec(); + + if argvec.is_empty() { + return default; + } else if argvec.len() == 1 { + return op(argvec[0].as_bigint(), &argvec[0]); + } + + let mut res = argvec[0].as_bigint().clone(); + for num in argvec[1..].iter() { + res = op(&res, &num) } + res } -fn math_gcd(args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { +fn math_gcd(args: Args) -> BigInt { use num_integer::Integer; - math_perf_arb_len_int_op(args, vm, |x, y| x.gcd(y.as_bigint()), BigInt::zero()) + math_perf_arb_len_int_op(args, |x, y| x.gcd(y.as_bigint()), BigInt::zero()) } -fn math_lcm(args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { +fn math_lcm(args: Args) -> BigInt { use num_integer::Integer; - math_perf_arb_len_int_op(args, vm, |x, y| x.lcm(y.as_bigint()), BigInt::one()) + math_perf_arb_len_int_op(args, |x, y| x.lcm(y.as_bigint()), BigInt::one()) } fn math_factorial(value: PyIntRef, vm: &VirtualMachine) -> PyResult { diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 001e7eb1d4..2f830ed2c0 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -30,6 +30,7 @@ pub mod socket; mod string; #[cfg(feature = "rustpython-compiler")] mod symtable; +#[cfg(not(target_arch = "wasm32"))] mod thread; mod time_module; #[cfg(feature = "rustpython-parser")] @@ -65,7 +66,7 @@ mod winreg; #[cfg(not(any(target_arch = "wasm32", target_os = "redox")))] mod zlib; -pub type StdlibInitFunc = Box PyObjectRef>; +pub type StdlibInitFunc = Box PyObjectRef + Send + Sync>; pub fn get_module_inits() -> HashMap { #[allow(unused_mut)] @@ -89,7 +90,6 @@ pub fn get_module_inits() -> HashMap { "_random".to_owned() => Box::new(random::make_module), "_string".to_owned() => Box::new(string::make_module), "_struct".to_owned() => Box::new(pystruct::make_module), - "_thread".to_owned() => Box::new(thread::make_module), "time".to_owned() => Box::new(time_module::make_module), "_weakref".to_owned() => Box::new(weakref::make_module), "_imp".to_owned() => Box::new(imp::make_module), @@ -130,6 +130,7 @@ pub fn get_module_inits() -> HashMap { #[cfg(feature = "ssl")] modules.insert("_ssl".to_owned(), Box::new(ssl::make_module)); modules.insert("_subprocess".to_owned(), Box::new(subprocess::make_module)); + modules.insert("_thread".to_owned(), Box::new(thread::make_module)); #[cfg(not(target_os = "redox"))] modules.insert("zlib".to_owned(), Box::new(zlib::make_module)); modules.insert( diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index bd2a4e6536..644ba374fe 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1,4 +1,3 @@ -use std::cell::{Cell, RefCell}; use std::ffi; use std::fs::File; use std::fs::OpenOptions; @@ -7,10 +6,12 @@ use std::io::{self, ErrorKind, Read, Write}; use std::os::unix::fs::OpenOptionsExt; #[cfg(windows)] use std::os::windows::fs::OpenOptionsExt; +use std::sync::RwLock; use std::time::{Duration, SystemTime}; use std::{env, fs}; use bitflags::bitflags; +use crossbeam_utils::atomic::AtomicCell; #[cfg(unix)] use nix::errno::Errno; #[cfg(all(unix, not(target_os = "redox")))] @@ -393,7 +394,7 @@ fn getgroups() -> nix::Result> { }) } -#[cfg(any(target_os = "linux", target_os = "android"))] +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] use nix::unistd::getgroups; #[cfg(target_os = "redox")] @@ -683,8 +684,8 @@ impl DirEntryRef { #[pyclass] #[derive(Debug)] struct ScandirIterator { - entries: RefCell, - exhausted: Cell, + entries: RwLock, + exhausted: AtomicCell, mode: OutputMode, } @@ -698,11 +699,11 @@ impl PyValue for ScandirIterator { impl ScandirIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.exhausted.get() { + if self.exhausted.load() { return Err(objiter::new_stop_iteration(vm)); } - match self.entries.borrow_mut().next() { + match self.entries.write().unwrap().next() { Some(entry) => match entry { Ok(entry) => Ok(DirEntry { entry, @@ -713,7 +714,7 @@ impl ScandirIterator { Err(s) => Err(convert_io_error(vm, s)), }, None => { - self.exhausted.set(true); + self.exhausted.store(true); Err(objiter::new_stop_iteration(vm)) } } @@ -721,7 +722,7 @@ impl ScandirIterator { #[pymethod] fn close(&self) { - self.exhausted.set(true); + self.exhausted.store(true); } #[pymethod(name = "__iter__")] @@ -748,8 +749,8 @@ fn os_scandir(path: OptionalArg, vm: &VirtualMachine) -> PyResult { match fs::read_dir(path.path) { Ok(iter) => Ok(ScandirIterator { - entries: RefCell::new(iter), - exhausted: Cell::new(false), + entries: RwLock::new(iter), + exhausted: AtomicCell::new(false), mode: path.mode, } .into_ref(vm) @@ -812,6 +813,8 @@ fn os_stat( use std::os::linux::fs::MetadataExt; #[cfg(target_os = "macos")] use std::os::macos::fs::MetadataExt; + #[cfg(target_os = "openbsd")] + use std::os::openbsd::fs::MetadataExt; #[cfg(target_os = "redox")] use std::os::redox::fs::MetadataExt; @@ -918,7 +921,8 @@ fn os_stat( target_os = "macos", target_os = "android", target_os = "redox", - windows + windows, + unix )))] fn os_stat( file: Either, @@ -1307,13 +1311,13 @@ fn os_urandom(size: usize, vm: &VirtualMachine) -> PyResult> { } } -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "openbsd"))] type ModeT = u32; #[cfg(target_os = "macos")] type ModeT = u16; -#[cfg(any(target_os = "macos", target_os = "linux"))] +#[cfg(any(target_os = "macos", target_os = "linux", target_os = "openbsd"))] fn os_umask(mask: ModeT, _vm: &VirtualMachine) -> PyResult { let ret_mask = unsafe { libc::umask(mask) }; Ok(ret_mask) @@ -1457,6 +1461,73 @@ fn os_sync(_vm: &VirtualMachine) -> PyResult<()> { Ok(()) } +// cfg from nix +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +fn os_getresuid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { + let mut ruid = 0; + let mut euid = 0; + let mut suid = 0; + let ret = unsafe { libc::getresuid(&mut ruid, &mut euid, &mut suid) }; + if ret == 0 { + Ok((ruid, euid, suid)) + } else { + Err(errno_err(vm)) + } +} + +// cfg from nix +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +fn os_getresgid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { + let mut rgid = 0; + let mut egid = 0; + let mut sgid = 0; + let ret = unsafe { libc::getresgid(&mut rgid, &mut egid, &mut sgid) }; + if ret == 0 { + Ok((rgid, egid, sgid)) + } else { + Err(errno_err(vm)) + } +} + +// cfg from nix +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +fn os_setregid(rgid: u32, egid: u32, vm: &VirtualMachine) -> PyResult { + let ret = unsafe { libc::setregid(rgid, egid) }; + if ret == 0 { + Ok(0) + } else { + Err(errno_err(vm)) + } +} + +// cfg from nix +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +fn os_initgroups(user_name: PyStringRef, gid: u32, vm: &VirtualMachine) -> PyResult<()> { + let user = ffi::CString::new(user_name.as_str()).unwrap(); + let gid = Gid::from_raw(gid); + unistd::initgroups(&user, gid).map_err(|err| convert_nix_error(vm, err)) +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -1688,6 +1759,10 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: &PyObjectRef) { ))] extend_module!(vm, module, { "setresuid" => ctx.new_function(os_setresuid), + "getresuid" => ctx.new_function(os_getresuid), + "getresgid" => ctx.new_function(os_getresgid), + "setregid" => ctx.new_function(os_setregid), + "initgroups" => ctx.new_function(os_initgroups), }); // cfg taken from nix diff --git a/vm/src/stdlib/random.rs b/vm/src/stdlib/random.rs index 63144451cc..acb8f9869e 100644 --- a/vm/src/stdlib/random.rs +++ b/vm/src/stdlib/random.rs @@ -11,18 +11,19 @@ mod _random { use crate::VirtualMachine; use num_bigint::{BigInt, Sign}; use num_traits::Signed; - use rand::RngCore; - use std::cell::RefCell; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + use std::sync::Mutex; #[derive(Debug)] enum PyRng { - Std(rand::rngs::ThreadRng), + Std(Box), MT(Box), } impl Default for PyRng { fn default() -> Self { - PyRng::Std(rand::thread_rng()) + PyRng::Std(Box::new(StdRng::from_entropy())) } } @@ -56,7 +57,7 @@ mod _random { #[pyclass(name = "Random")] #[derive(Debug)] struct PyRandom { - rng: RefCell, + rng: Mutex, } impl PyValue for PyRandom { @@ -70,14 +71,15 @@ mod _random { #[pyslot(new)] fn new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult> { PyRandom { - rng: RefCell::new(PyRng::default()), + rng: Mutex::default(), } .into_ref_with_type(vm, cls) } #[pymethod] fn random(&self) -> f64 { - mt19937::gen_res53(&mut *self.rng.borrow_mut()) + let mut rng = self.rng.lock().unwrap(); + mt19937::gen_res53(&mut *rng) } #[pymethod] @@ -93,13 +95,13 @@ mod _random { } }; - *self.rng.borrow_mut() = new_rng; + *self.rng.lock().unwrap() = new_rng; } #[pymethod] - fn getrandbits(&self, mut k: usize) -> BigInt { - let mut rng = self.rng.borrow_mut(); - + fn getrandbits(&self, k: usize) -> BigInt { + let mut rng = self.rng.lock().unwrap(); + let mut k = k; let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32; if k <= 32 { diff --git a/vm/src/stdlib/signal.rs b/vm/src/stdlib/signal.rs index b5c9baabfe..b3471d62c0 100644 --- a/vm/src/stdlib/signal.rs +++ b/vm/src/stdlib/signal.rs @@ -38,6 +38,10 @@ fn assert_in_range(signum: i32, vm: &VirtualMachine) -> PyResult<()> { fn signal(signalnum: i32, handler: PyObjectRef, vm: &VirtualMachine) -> PyResult { assert_in_range(signalnum, vm)?; + let signal_handlers = vm + .signal_handlers + .as_ref() + .ok_or_else(|| vm.new_value_error("signal only works in main thread".to_owned()))?; let sig_handler = match usize::try_from_object(vm, handler.clone()).ok() { Some(SIG_DFL) => SIG_DFL, @@ -68,7 +72,7 @@ fn signal(signalnum: i32, handler: PyObjectRef, vm: &VirtualMachine) -> PyResult let mut old_handler = handler; std::mem::swap( - &mut vm.signal_handlers.borrow_mut()[signalnum as usize], + &mut signal_handlers.borrow_mut()[signalnum as usize], &mut old_handler, ); Ok(old_handler) @@ -76,7 +80,11 @@ fn signal(signalnum: i32, handler: PyObjectRef, vm: &VirtualMachine) -> PyResult fn getsignal(signalnum: i32, vm: &VirtualMachine) -> PyResult { assert_in_range(signalnum, vm)?; - Ok(vm.signal_handlers.borrow()[signalnum as usize].clone()) + let signal_handlers = vm + .signal_handlers + .as_ref() + .ok_or_else(|| vm.new_value_error("getsignal only works in main thread".to_owned()))?; + Ok(signal_handlers.borrow()[signalnum as usize].clone()) } #[cfg(unix)] @@ -91,13 +99,18 @@ fn alarm(time: u32) -> u32 { #[cfg_attr(feature = "flame-it", flame)] pub fn check_signals(vm: &VirtualMachine) -> PyResult<()> { + let signal_handlers = match vm.signal_handlers { + Some(ref h) => h.borrow(), + None => return Ok(()), + }; + if !ANY_TRIGGERED.swap(false, Ordering::Relaxed) { return Ok(()); } for (signum, trigger) in TRIGGERS.iter().enumerate().skip(1) { let triggerd = trigger.swap(false, Ordering::Relaxed); if triggerd { - let handler = &vm.signal_handlers.borrow()[signum]; + let handler = &signal_handlers[signum]; if vm.is_callable(handler) { vm.invoke(handler, vec![vm.new_int(signum), vm.get_none()])?; } @@ -145,7 +158,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { } else { vm.get_none() }; - vm.signal_handlers.borrow_mut()[signum] = py_handler; + vm.signal_handlers.as_ref().unwrap().borrow_mut()[signum] = py_handler; } signal(libc::SIGINT, int_handler, vm).expect("Failed to set sigint handler"); @@ -184,7 +197,7 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: &PyObjectRef) { "SIGSYS" => ctx.new_int(libc::SIGSYS as u8), }); - #[cfg(not(target_os = "macos"))] + #[cfg(not(any(target_os = "macos", target_os = "openbsd")))] { extend_module!(vm, module, { "SIGPWR" => ctx.new_int(libc::SIGPWR as u8), diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 2969ae5a7a..2e21b92d9b 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -1,9 +1,10 @@ -use std::cell::{Cell, Ref, RefCell}; use std::io::{self, prelude::*}; use std::net::{Ipv4Addr, Shutdown, SocketAddr, ToSocketAddrs}; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use std::time::Duration; use byteorder::{BigEndian, ByteOrder}; +use crossbeam_utils::atomic::AtomicCell; use gethostname::gethostname; #[cfg(all(unix, not(target_os = "redox")))] use nix::unistd::sethostname; @@ -44,10 +45,10 @@ mod c { #[pyclass] #[derive(Debug)] pub struct PySocket { - kind: Cell, - family: Cell, - proto: Cell, - sock: RefCell, + kind: AtomicCell, + family: AtomicCell, + proto: AtomicCell, + sock: RwLock, } impl PyValue for PySocket { @@ -60,17 +61,21 @@ pub type PySocketRef = PyRef; #[pyimpl(flags(BASETYPE))] impl PySocket { - fn sock(&self) -> Ref { - self.sock.borrow() + fn sock(&self) -> RwLockReadGuard<'_, Socket> { + self.sock.read().unwrap() + } + + fn sock_mut(&self) -> RwLockWriteGuard<'_, Socket> { + self.sock.write().unwrap() } #[pyslot] fn tp_new(cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { PySocket { - kind: Cell::default(), - family: Cell::default(), - proto: Cell::default(), - sock: RefCell::new(invalid_sock()), + kind: AtomicCell::default(), + family: AtomicCell::default(), + proto: AtomicCell::default(), + sock: RwLock::new(invalid_sock()), } .into_ref_with_type(vm, cls) } @@ -103,12 +108,12 @@ impl PySocket { ) .map_err(|err| convert_sock_error(vm, err))?; - self.family.set(family); - self.kind.set(socket_kind); - self.proto.set(proto); + self.family.store(family); + self.kind.store(socket_kind); + self.proto.store(proto); sock }; - self.sock.replace(sock); + *self.sock.write().unwrap() = sock; Ok(()) } @@ -191,7 +196,7 @@ impl PySocket { #[pymethod] fn sendall(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult<()> { bytes - .with_ref(|b| self.sock.borrow_mut().write_all(b)) + .with_ref(|b| self.sock_mut().write_all(b)) .map_err(|err| convert_sock_error(vm, err)) } @@ -206,11 +211,11 @@ impl PySocket { #[pymethod] fn close(&self) { - self.sock.replace(invalid_sock()); + *self.sock_mut() = invalid_sock(); } #[pymethod] fn detach(&self) -> RawSocket { - into_sock_fileno(self.sock.replace(invalid_sock())) + into_sock_fileno(std::mem::replace(&mut *self.sock_mut(), invalid_sock())) } #[pymethod] @@ -384,29 +389,29 @@ impl PySocket { #[pyproperty(name = "type")] fn kind(&self) -> i32 { - self.kind.get() + self.kind.load() } #[pyproperty] fn family(&self) -> i32 { - self.family.get() + self.family.load() } #[pyproperty] fn proto(&self) -> i32 { - self.proto.get() + self.proto.load() } } impl io::Read for PySocketRef { fn read(&mut self, buf: &mut [u8]) -> io::Result { - ::read(&mut self.sock.borrow_mut(), buf) + ::read(&mut self.sock_mut(), buf) } } impl io::Write for PySocketRef { fn write(&mut self, buf: &[u8]) -> io::Result { - ::write(&mut self.sock.borrow_mut(), buf) + ::write(&mut self.sock_mut(), buf) } fn flush(&mut self) -> io::Result<()> { - ::flush(&mut self.sock.borrow_mut()) + ::flush(&mut self.sock_mut()) } } diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 7cb133b93c..02c52a427b 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -11,10 +11,10 @@ use crate::pyobject::{ use crate::types::create_type; use crate::VirtualMachine; -use std::cell::{Ref, RefCell, RefMut}; use std::convert::TryFrom; use std::ffi::{CStr, CString}; use std::fmt; +use std::sync::{RwLock, RwLockWriteGuard}; use foreign_types_shared::{ForeignType, ForeignTypeRef}; use openssl::{ @@ -230,7 +230,7 @@ fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec, bool #[pyclass(name = "_SSLContext")] struct PySslContext { - ctx: RefCell, + ctx: RwLock, check_hostname: bool, } @@ -248,16 +248,18 @@ impl PyValue for PySslContext { #[pyimpl(flags(BASETYPE))] impl PySslContext { - fn builder(&self) -> RefMut { - self.ctx.borrow_mut() + fn builder(&self) -> RwLockWriteGuard<'_, SslContextBuilder> { + self.ctx.write().unwrap() } - fn ctx(&self) -> Ref { - Ref::map(self.ctx.borrow(), |ctx| unsafe { - &**(ctx as *const SslContextBuilder as *const ssl::SslContext) - }) + fn exec_ctx(&self, func: F) -> R + where + F: Fn(&ssl::SslContextRef) -> R, + { + let c = self.ctx.read().unwrap(); + func(unsafe { &**(&*c as *const SslContextBuilder as *const ssl::SslContext) }) } fn ptr(&self) -> *mut sys::SSL_CTX { - self.ctx.borrow().as_ptr() + (*self.ctx.write().unwrap()).as_ptr() } #[pyslot] @@ -306,7 +308,7 @@ impl PySslContext { .map_err(|e| convert_openssl_error(vm, e))?; PySslContext { - ctx: RefCell::new(builder), + ctx: RwLock::new(builder), check_hostname, } .into_ref_with_type(vm, cls) @@ -325,7 +327,7 @@ impl PySslContext { #[pyproperty] fn verify_mode(&self) -> i32 { - let mode = self.ctx().verify_mode(); + let mode = self.exec_ctx(|ctx| ctx.verify_mode()); if mode == SslVerifyMode::NONE { CertRequirements::None.into() } else if mode == SslVerifyMode::PEER { @@ -385,9 +387,10 @@ impl PySslContext { Either::B(b) => b.with_ref(X509::from_der), }; let cert = cert.map_err(|e| convert_openssl_error(vm, e))?; - let ctx = self.ctx(); - let store = ctx.cert_store(); - let ret = unsafe { sys::X509_STORE_add_cert(store.as_ptr(), cert.as_ptr()) }; + let ret = self.exec_ctx(|ctx| { + let store = ctx.cert_store(); + unsafe { sys::X509_STORE_add_cert(store.as_ptr(), cert.as_ptr()) } + }); if ret <= 0 { return Err(convert_openssl_error(vm, ErrorStack::get())); } @@ -424,7 +427,8 @@ impl PySslContext { use openssl::stack::StackRef; let binary_form = binary_form.unwrap_or(false); let certs = unsafe { - let stack = sys::X509_STORE_get0_objects(self.ctx().cert_store().as_ptr()); + let stack = + sys::X509_STORE_get0_objects(self.exec_ctx(|ctx| ctx.cert_store().as_ptr())); assert!(!stack.is_null()); StackRef::::from_ptr(stack) }; @@ -467,10 +471,10 @@ impl PySslContext { Ok(PySslSocket { ctx: zelf, - stream: RefCell::new(Some(stream)), + stream: RwLock::new(Some(stream)), socket_type, server_hostname: args.server_hostname, - owner: RefCell::new(args.owner.as_ref().map(PyWeak::downgrade)), + owner: RwLock::new(args.owner.as_ref().map(PyWeak::downgrade)), }) } } @@ -503,10 +507,10 @@ struct LoadVerifyLocationsArgs { #[pyclass(name = "_SSLSocket")] struct PySslSocket { ctx: PyRef, - stream: RefCell>>, + stream: RwLock>>, socket_type: SslServerOrClient, server_hostname: Option, - owner: RefCell>, + owner: RwLock>, } impl fmt::Debug for PySslSocket { @@ -524,28 +528,32 @@ impl PyValue for PySslSocket { #[pyimpl] impl PySslSocket { fn stream_builder(&self) -> ssl::SslStreamBuilder { - self.stream.replace(None).unwrap() - } - fn stream(&self) -> RefMut> { - RefMut::map(self.stream.borrow_mut(), |b| { - let b = b.as_mut().unwrap(); - unsafe { &mut *(b as *mut ssl::SslStreamBuilder<_> as *mut ssl::SslStream<_>) } + std::mem::replace(&mut *self.stream.write().unwrap(), None).unwrap() + } + fn exec_stream(&self, func: F) -> R + where + F: Fn(&mut ssl::SslStream) -> R, + { + let mut b = self.stream.write().unwrap(); + func(unsafe { + &mut *(b.as_mut().unwrap() as *mut ssl::SslStreamBuilder<_> as *mut ssl::SslStream<_>) }) } fn set_stream(&self, stream: ssl::SslStream) { - let prev = self - .stream - .replace(Some(unsafe { std::mem::transmute(stream) })); - debug_assert!(prev.is_none()); + *self.stream.write().unwrap() = Some(unsafe { std::mem::transmute(stream) }); } #[pyproperty] fn owner(&self) -> Option { - self.owner.borrow().as_ref().and_then(PyWeak::upgrade) + self.owner + .read() + .unwrap() + .as_ref() + .and_then(PyWeak::upgrade) } #[pyproperty(setter)] fn set_owner(&self, owner: PyObjectRef) { - *self.owner.borrow_mut() = Some(PyWeak::downgrade(&owner)) + *self.owner.write().unwrap() = Some(PyWeak::downgrade(&owner)) } #[pyproperty] fn server_side(&self) -> bool { @@ -567,12 +575,10 @@ impl PySslSocket { vm: &VirtualMachine, ) -> PyResult> { let binary = binary.unwrap_or(false); - if !self.stream().ssl().is_init_finished() { + if !self.exec_stream(|stream| stream.ssl().is_init_finished()) { return Err(vm.new_value_error("handshake not done yet".to_owned())); } - self.stream() - .ssl() - .peer_certificate() + self.exec_stream(|stream| stream.ssl().peer_certificate()) .map(|cert| cert_to_py(vm, &cert, binary)) .transpose() } @@ -605,17 +611,18 @@ impl PySslSocket { #[pymethod] fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult { - data.with_ref(|b| self.stream().ssl_write(b)) + data.with_ref(|b| self.exec_stream(|stream| stream.ssl_write(b))) .map_err(|e| convert_ssl_error(vm, e)) } #[pymethod] fn read(&self, n: usize, buffer: OptionalArg, vm: &VirtualMachine) -> PyResult { if let OptionalArg::Present(buffer) = buffer { - let mut buf = buffer.borrow_value_mut(); let n = self - .stream() - .ssl_read(&mut buf.elements) + .exec_stream(|stream| { + let mut buf = buffer.borrow_value_mut(); + stream.ssl_read(&mut buf.elements) + }) .map_err(|e| convert_ssl_error(vm, e))?; Ok(vm.new_int(n)) } else { diff --git a/vm/src/stdlib/subprocess.rs b/vm/src/stdlib/subprocess.rs index 9694edb82a..96714398c0 100644 --- a/vm/src/stdlib/subprocess.rs +++ b/vm/src/stdlib/subprocess.rs @@ -1,7 +1,7 @@ -use std::cell::RefCell; use std::ffi::OsString; use std::fs::File; use std::io::ErrorKind; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use std::time::Duration; use crate::function::OptionalArg; @@ -16,10 +16,14 @@ use crate::vm::VirtualMachine; #[derive(Debug)] struct Popen { - process: RefCell, + process: RwLock, args: PyObjectRef, } +// Remove once https://github.com/hniksic/rust-subprocess/issues/42 is resolved +#[cfg(windows)] +unsafe impl Sync for Popen {} + impl PyValue for Popen { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("_subprocess", "Popen") @@ -103,6 +107,14 @@ fn convert_to_file_io(file: &Option, mode: &str, vm: &VirtualMachine) -> P } impl PopenRef { + fn borrow_process(&self) -> RwLockReadGuard<'_, subprocess::Popen> { + self.process.read().unwrap() + } + + fn borrow_process_mut(&self) -> RwLockWriteGuard<'_, subprocess::Popen> { + self.process.write().unwrap() + } + fn new(cls: PyClassRef, args: PopenArgs, vm: &VirtualMachine) -> PyResult { let stdin = convert_redirection(args.stdin, vm)?; let stdout = convert_redirection(args.stdout, vm)?; @@ -130,27 +142,26 @@ impl PopenRef { .map_err(|s| vm.new_os_error(format!("Could not start program: {}", s)))?; Popen { - process: RefCell::new(process), + process: RwLock::new(process), args: args.args.into_object(), } .into_ref_with_type(vm, cls) } fn poll(self) -> Option { - self.process.borrow_mut().poll() + self.borrow_process_mut().poll() } fn return_code(self) -> Option { - self.process.borrow().exit_status() + self.borrow_process().exit_status() } fn wait(self, args: PopenWaitArgs, vm: &VirtualMachine) -> PyResult { let timeout = match args.timeout { Some(timeout) => self - .process - .borrow_mut() + .borrow_process_mut() .wait_timeout(Duration::new(timeout, 0)), - None => self.process.borrow_mut().wait().map(Some), + None => self.borrow_process_mut().wait().map(Some), } .map_err(|s| vm.new_os_error(format!("Could not start program: {}", s)))?; if let Some(exit) = timeout { @@ -167,27 +178,25 @@ impl PopenRef { } fn stdin(self, vm: &VirtualMachine) -> PyResult { - convert_to_file_io(&self.process.borrow().stdin, "wb", vm) + convert_to_file_io(&self.borrow_process().stdin, "wb", vm) } fn stdout(self, vm: &VirtualMachine) -> PyResult { - convert_to_file_io(&self.process.borrow().stdout, "rb", vm) + convert_to_file_io(&self.borrow_process().stdout, "rb", vm) } fn stderr(self, vm: &VirtualMachine) -> PyResult { - convert_to_file_io(&self.process.borrow().stderr, "rb", vm) + convert_to_file_io(&self.borrow_process().stderr, "rb", vm) } fn terminate(self, vm: &VirtualMachine) -> PyResult<()> { - self.process - .borrow_mut() + self.borrow_process_mut() .terminate() .map_err(|err| convert_io_error(vm, err)) } fn kill(self, vm: &VirtualMachine) -> PyResult<()> { - self.process - .borrow_mut() + self.borrow_process_mut() .kill() .map_err(|err| convert_io_error(vm, err)) } @@ -202,7 +211,7 @@ impl PopenRef { OptionalArg::Present(ref bytes) => Some(bytes.get_value().to_vec()), OptionalArg::Missing => None, }; - let mut communicator = self.process.borrow_mut().communicate_start(bytes); + let mut communicator = self.borrow_process_mut().communicate_start(bytes); if let OptionalArg::Present(timeout) = args.timeout { communicator = communicator.limit_time(Duration::new(timeout, 0)); } @@ -217,7 +226,7 @@ impl PopenRef { } fn pid(self) -> Option { - self.process.borrow().pid() + self.borrow_process().pid() } fn enter(self) -> Self { @@ -230,7 +239,7 @@ impl PopenRef { _exception_value: PyObjectRef, _traceback: PyObjectRef, ) { - let mut process = self.process.borrow_mut(); + let mut process = self.borrow_process_mut(); process.stdout.take(); process.stdin.take(); process.stderr.take(); diff --git a/vm/src/stdlib/thread.rs b/vm/src/stdlib/thread.rs index cf9c8749be..1a3ec9a94d 100644 --- a/vm/src/stdlib/thread.rs +++ b/vm/src/stdlib/thread.rs @@ -1,62 +1,282 @@ -/// Implementation of the _thread module, currently noop implementation as RustPython doesn't yet -/// support threading -use crate::function::PyFuncArgs; -use crate::pyobject::{PyObjectRef, PyResult}; +/// Implementation of the _thread module +use crate::exceptions; +use crate::function::{Args, KwArgs, OptionalArg, PyFuncArgs}; +use crate::obj::objdict::PyDictRef; +use crate::obj::objtuple::PyTupleRef; +use crate::obj::objtype::PyClassRef; +use crate::pyobject::{ + Either, IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, + TypeProtocol, +}; use crate::vm::VirtualMachine; +use parking_lot::{ + lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, + RawMutex, RawThreadId, +}; +use std::cell::RefCell; +use std::io::Write; +use std::time::Duration; +use std::{fmt, thread}; + +// PY_TIMEOUT_MAX is a value in microseconds #[cfg(not(target_os = "windows"))] -const PY_TIMEOUT_MAX: isize = std::isize::MAX; +const PY_TIMEOUT_MAX: isize = std::isize::MAX / 1_000; #[cfg(target_os = "windows")] -const PY_TIMEOUT_MAX: isize = 0xffffffff * 1_000_000; +const PY_TIMEOUT_MAX: isize = 0xffffffff * 1_000; + +// this is a value in seconds +const TIMEOUT_MAX: f64 = (PY_TIMEOUT_MAX / 1_000_000) as f64; + +#[derive(FromArgs)] +struct AcquireArgs { + #[pyarg(positional_or_keyword, default = "true")] + blocking: bool, + #[pyarg(positional_or_keyword, default = "Either::A(-1.0)")] + timeout: Either, +} + +macro_rules! acquire_lock_impl { + ($mu:expr, $args:expr, $vm:expr) => {{ + let (mu, args, vm) = ($mu, $args, $vm); + let timeout = match args.timeout { + Either::A(f) => f, + Either::B(i) => i as f64, + }; + match args.blocking { + true if timeout == -1.0 => { + mu.lock(); + Ok(true) + } + true if timeout < 0.0 => { + Err(vm.new_value_error("timeout value must be positive".to_owned())) + } + true => { + // modified from std::time::Duration::from_secs_f64 to avoid a panic. + // TODO: put this in the Duration::try_from_object impl, maybe? + let micros = timeout * 1_000_000.0; + let nanos = timeout * 1_000_000_000.0; + if micros > PY_TIMEOUT_MAX as f64 || nanos < 0.0 || !nanos.is_finite() { + return Err(vm.new_overflow_error( + "timestamp too large to convert to Rust Duration".to_owned(), + )); + } -const TIMEOUT_MAX: f64 = (PY_TIMEOUT_MAX / 1_000_000_000) as f64; + Ok(mu.try_lock_for(Duration::from_secs_f64(timeout))) + } + false if timeout != -1.0 => { + Err(vm + .new_value_error("can't specify a timeout for a non-blocking call".to_owned())) + } + false => Ok(mu.try_lock()), + } + }}; +} +macro_rules! repr_lock_impl { + ($zelf:expr) => {{ + let status = if $zelf.mu.is_locked() { + "locked" + } else { + "unlocked" + }; + format!( + "<{} {} object at {}>", + status, + $zelf.class().name, + $zelf.get_id() + ) + }}; +} -fn rlock_acquire(vm: &VirtualMachine, _args: PyFuncArgs) -> PyResult { - Ok(vm.get_none()) +#[pyclass(name = "lock")] +struct PyLock { + mu: RawMutex, } +type PyLockRef = PyRef; -fn rlock_release(_zelf: PyObjectRef) {} +impl PyValue for PyLock { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_thread", "LockType") + } +} + +impl fmt::Debug for PyLock { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad("PyLock") + } +} + +#[pyimpl] +impl PyLock { + #[pymethod] + #[pymethod(name = "acquire_lock")] + #[pymethod(name = "__enter__")] + #[allow(clippy::float_cmp, clippy::match_bool)] + fn acquire(&self, args: AcquireArgs, vm: &VirtualMachine) -> PyResult { + acquire_lock_impl!(&self.mu, args, vm) + } + #[pymethod] + #[pymethod(name = "release_lock")] + fn release(&self) { + self.mu.unlock() + } + + #[pymethod(magic)] + fn exit(&self, _args: PyFuncArgs) { + self.release() + } + + #[pymethod] + fn locked(&self) -> bool { + self.mu.is_locked() + } + + #[pymethod(magic)] + fn repr(zelf: PyRef) -> String { + repr_lock_impl!(zelf) + } +} + +type RawRMutex = RawReentrantMutex; +#[pyclass(name = "RLock")] +struct PyRLock { + mu: RawRMutex, +} + +impl PyValue for PyRLock { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_thread", "RLock") + } +} -fn rlock_enter(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(instance, None)]); - Ok(instance.clone()) +impl fmt::Debug for PyRLock { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad("PyRLock") + } } -fn rlock_exit( - // The context manager protocol requires these, but we don't use them - _instance: PyObjectRef, - _exception_type: PyObjectRef, - _exception_value: PyObjectRef, - _traceback: PyObjectRef, +#[pyimpl] +impl PyRLock { + #[pyslot] + fn tp_new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult> { + PyRLock { + mu: RawRMutex::INIT, + } + .into_ref_with_type(vm, cls) + } + + #[pymethod] + #[pymethod(name = "acquire_lock")] + #[pymethod(name = "__enter__")] + #[allow(clippy::float_cmp, clippy::match_bool)] + fn acquire(&self, args: AcquireArgs, vm: &VirtualMachine) -> PyResult { + acquire_lock_impl!(&self.mu, args, vm) + } + #[pymethod] + #[pymethod(name = "release_lock")] + fn release(&self) { + self.mu.unlock() + } + + #[pymethod(magic)] + fn exit(&self, _args: PyFuncArgs) { + self.release() + } + + #[pymethod(magic)] + fn repr(zelf: PyRef) -> String { + repr_lock_impl!(zelf) + } +} + +fn thread_get_ident() -> u64 { + thread_to_id(&thread::current()) +} + +fn thread_to_id(t: &thread::Thread) -> u64 { + // TODO: use id.as_u64() once it's stable, until then, ThreadId is just a wrapper + // around NonZeroU64, so this is safe + unsafe { std::mem::transmute(t.id()) } +} + +fn thread_allocate_lock() -> PyLock { + PyLock { mu: RawMutex::INIT } +} + +fn thread_start_new_thread( + func: PyCallable, + args: PyTupleRef, + kwargs: OptionalArg, vm: &VirtualMachine, -) -> PyResult { - Ok(vm.get_none()) +) -> PyResult { + let thread_vm = vm.new_thread(); + let mut thread_builder = thread::Builder::new(); + let stacksize = vm.state.stacksize.load(); + if stacksize != 0 { + thread_builder = thread_builder.stack_size(stacksize); + } + let res = thread_builder.spawn(move || { + let vm = &thread_vm; + let args = Args::from(args.as_slice().to_owned()); + let kwargs = KwArgs::from(kwargs.map_or_else(Default::default, |k| k.to_attributes())); + if let Err(exc) = func.invoke(PyFuncArgs::from((args, kwargs)), vm) { + // TODO: sys.unraisablehook + let stderr = std::io::stderr(); + let mut stderr = stderr.lock(); + let repr = vm.to_repr(&func.into_object()).ok(); + let repr = repr + .as_ref() + .map_or("", |s| s.as_str()); + writeln!(stderr, "Exception ignored in thread started by: {}", repr) + .and_then(|()| exceptions::write_exception(&mut stderr, vm, &exc)) + .ok(); + } + SENTINELS.with(|sents| { + for lock in sents.replace(Default::default()) { + lock.mu.unlock() + } + }); + vm.state.thread_count.fetch_sub(1); + }); + res.map(|handle| { + vm.state.thread_count.fetch_add(1); + thread_to_id(&handle.thread()) + }) + .map_err(|err| super::os::convert_io_error(vm, err)) } -fn get_ident(_vm: &VirtualMachine) -> u32 { - 1 +thread_local!(static SENTINELS: RefCell> = RefCell::default()); + +fn thread_set_sentinel(vm: &VirtualMachine) -> PyLockRef { + let lock = PyLock { mu: RawMutex::INIT }.into_ref(vm); + SENTINELS.with(|sents| sents.borrow_mut().push(lock.clone())); + lock } -fn allocate_lock(vm: &VirtualMachine) -> PyResult { - let lock_class = vm.class("_thread", "RLock"); - vm.invoke(&lock_class.into_object(), vec![]) +fn thread_stack_size(size: OptionalArg, vm: &VirtualMachine) -> usize { + let size = size.unwrap_or(0); + // TODO: do validation on this to make sure it's not too small + vm.state.stacksize.swap(size) +} + +fn thread_count(vm: &VirtualMachine) -> usize { + vm.state.thread_count.load() } pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; - let rlock_type = py_class!(ctx, "_thread.RLock", ctx.object(), { - "acquire" => ctx.new_method(rlock_acquire), - "release" => ctx.new_method(rlock_release), - "__enter__" => ctx.new_method(rlock_enter), - "__exit__" => ctx.new_method(rlock_exit), - }); - py_module!(vm, "_thread", { - "RLock" => rlock_type, - "get_ident" => ctx.new_function(get_ident), - "allocate_lock" => ctx.new_function(allocate_lock), + "RLock" => PyRLock::make_class(ctx), + "LockType" => PyLock::make_class(ctx), + "get_ident" => ctx.new_function(thread_get_ident), + "allocate_lock" => ctx.new_function(thread_allocate_lock), + "start_new_thread" => ctx.new_function(thread_start_new_thread), + "_set_sentinel" => ctx.new_function(thread_set_sentinel), + "stack_size" => ctx.new_function(thread_stack_size), + "_count" => ctx.new_function(thread_count), + "error" => ctx.exceptions.runtime_error.clone(), "TIMEOUT_MAX" => ctx.new_float(TIMEOUT_MAX), }) } diff --git a/vm/src/stdlib/time_module.rs b/vm/src/stdlib/time_module.rs index a807535522..f697d4a49c 100644 --- a/vm/src/stdlib/time_module.rs +++ b/vm/src/stdlib/time_module.rs @@ -70,32 +70,41 @@ fn time_monotonic(_vm: &VirtualMachine) -> f64 { } } -fn pyobj_to_naive_date_time(value: Either) -> NaiveDateTime { - match value { +fn pyobj_to_naive_date_time( + value: Either, + vm: &VirtualMachine, +) -> PyResult { + let timestamp = match value { Either::A(float) => { let secs = float.trunc() as i64; let nsecs = (float.fract() * 1e9) as u32; - NaiveDateTime::from_timestamp(secs, nsecs) + NaiveDateTime::from_timestamp_opt(secs, nsecs) } - Either::B(int) => NaiveDateTime::from_timestamp(int, 0), - } + Either::B(int) => NaiveDateTime::from_timestamp_opt(int, 0), + }; + timestamp.ok_or_else(|| { + vm.new_overflow_error("timestamp out of range for platform time_t".to_owned()) + }) } /// https://docs.python.org/3/library/time.html?highlight=gmtime#time.gmtime -fn time_gmtime(secs: OptionalArg>, vm: &VirtualMachine) -> PyObjectRef { +fn time_gmtime(secs: OptionalArg>, vm: &VirtualMachine) -> PyResult { let default = chrono::offset::Utc::now().naive_utc(); let instant = match secs { - OptionalArg::Present(secs) => pyobj_to_naive_date_time(secs), + OptionalArg::Present(secs) => pyobj_to_naive_date_time(secs, vm)?, OptionalArg::Missing => default, }; - PyStructTime::new(vm, instant, 0).into_obj(vm) + Ok(PyStructTime::new(vm, instant, 0).into_obj(vm)) } -fn time_localtime(secs: OptionalArg>, vm: &VirtualMachine) -> PyObjectRef { - let instant = optional_or_localtime(secs); +fn time_localtime( + secs: OptionalArg>, + vm: &VirtualMachine, +) -> PyResult { + let instant = optional_or_localtime(secs, vm)?; // TODO: isdst flag must be valid value here // https://docs.python.org/3/library/time.html#time.localtime - PyStructTime::new(vm, instant, -1).into_obj(vm) + Ok(PyStructTime::new(vm, instant, -1).into_obj(vm)) } fn time_mktime(t: PyStructTime, vm: &VirtualMachine) -> PyResult { @@ -105,12 +114,15 @@ fn time_mktime(t: PyStructTime, vm: &VirtualMachine) -> PyResult { } /// Construct a localtime from the optional seconds, or get the current local time. -fn optional_or_localtime(secs: OptionalArg>) -> NaiveDateTime { +fn optional_or_localtime( + secs: OptionalArg>, + vm: &VirtualMachine, +) -> PyResult { let default = chrono::offset::Local::now().naive_local(); - match secs { - OptionalArg::Present(secs) => pyobj_to_naive_date_time(secs), + Ok(match secs { + OptionalArg::Present(secs) => pyobj_to_naive_date_time(secs, vm)?, OptionalArg::Missing => default, - } + }) } const CFMT: &str = "%a %b %e %H:%M:%S %Y"; @@ -125,9 +137,9 @@ fn time_asctime(t: OptionalArg, vm: &VirtualMachine) -> PyResult { Ok(vm.ctx.new_str(formatted_time)) } -fn time_ctime(secs: OptionalArg>) -> String { - let instant = optional_or_localtime(secs); - instant.format(&CFMT).to_string() +fn time_ctime(secs: OptionalArg>, vm: &VirtualMachine) -> PyResult { + let instant = optional_or_localtime(secs, vm)?; + Ok(instant.format(&CFMT).to_string()) } fn time_strftime( diff --git a/vm/src/stdlib/winreg.rs b/vm/src/stdlib/winreg.rs index e08eabc8cc..dba3285db8 100644 --- a/vm/src/stdlib/winreg.rs +++ b/vm/src/stdlib/winreg.rs @@ -1,8 +1,7 @@ #![allow(non_snake_case)] - -use std::cell::{Ref, RefCell}; use std::convert::TryInto; use std::io; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use super::os; use crate::function::OptionalArg; @@ -17,10 +16,13 @@ use winreg::{enums::RegType, RegKey, RegValue}; #[pyclass] #[derive(Debug)] struct PyHKEY { - key: RefCell, + key: RwLock, } type PyHKEYRef = PyRef; +// TODO: fix this +unsafe impl Sync for PyHKEY {} + impl PyValue for PyHKEY { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("winreg", "HKEYType") @@ -31,24 +33,28 @@ impl PyValue for PyHKEY { impl PyHKEY { fn new(key: RegKey) -> Self { Self { - key: RefCell::new(key), + key: RwLock::new(key), } } - fn key(&self) -> Ref { - self.key.borrow() + fn key(&self) -> RwLockReadGuard<'_, RegKey> { + self.key.read().unwrap() + } + + fn key_mut(&self) -> RwLockWriteGuard<'_, RegKey> { + self.key.write().unwrap() } #[pymethod] fn Close(&self) { let null_key = RegKey::predef(0 as winreg::HKEY); - let key = self.key.replace(null_key); + let key = std::mem::replace(&mut *self.key_mut(), null_key); drop(key); } #[pymethod] fn Detach(&self) -> usize { let null_key = RegKey::predef(0 as winreg::HKEY); - let key = self.key.replace(null_key); + let key = std::mem::replace(&mut *self.key_mut(), null_key); let handle = key.raw_handle(); std::mem::forget(key); handle as usize diff --git a/vm/src/sysmodule.rs b/vm/src/sysmodule.rs index 73d360d806..007d7ef8e8 100644 --- a/vm/src/sysmodule.rs +++ b/vm/src/sysmodule.rs @@ -18,7 +18,8 @@ use crate::vm::{PySettings, VirtualMachine}; fn argv(vm: &VirtualMachine) -> PyObjectRef { vm.ctx.new_list( - vm.settings + vm.state + .settings .argv .iter() .map(|arg| vm.new_str(arg.to_owned())) @@ -150,7 +151,7 @@ fn update_use_tracing(vm: &VirtualMachine) { let trace_is_none = vm.is_none(&vm.trace_func.borrow()); let profile_is_none = vm.is_none(&vm.profile_func.borrow()); let tracing = !(trace_is_none && profile_is_none); - vm.use_tracing.replace(tracing); + vm.use_tracing.set(tracing); } fn sys_getrecursionlimit(vm: &VirtualMachine) -> usize { @@ -225,7 +226,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef, builtins: PyObjectR let ctx = &vm.ctx; let flags_type = SysFlags::make_class(ctx); - let flags = SysFlags::from_settings(&vm.settings) + let flags = SysFlags::from_settings(&vm.state.settings) .into_struct_sequence(vm, flags_type) .unwrap(); @@ -246,7 +247,8 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef, builtins: PyObjectR }); let path = ctx.new_list( - vm.settings + vm.state + .settings .path_list .iter() .map(|path| ctx.new_str(path.clone())) @@ -348,7 +350,7 @@ setprofile() -- set the global profiling function setrecursionlimit() -- set the max recursion depth for the interpreter settrace() -- set the global debug tracing function "; - let mut module_names: Vec = vm.stdlib_inits.borrow().keys().cloned().collect(); + let mut module_names: Vec = vm.state.stdlib_inits.keys().cloned().collect(); module_names.push("sys".to_owned()); module_names.push("builtins".to_owned()); module_names.sort(); @@ -399,7 +401,7 @@ settrace() -- set the global debug tracing function "path_hooks" => ctx.new_list(vec![]), "path_importer_cache" => ctx.new_dict(), "pycache_prefix" => vm.get_none(), - "dont_write_bytecode" => vm.new_bool(vm.settings.dont_write_bytecode), + "dont_write_bytecode" => vm.new_bool(vm.state.settings.dont_write_bytecode), "setprofile" => ctx.new_function(sys_setprofile), "setrecursionlimit" => ctx.new_function(sys_setrecursionlimit), "settrace" => ctx.new_function(sys_settrace), diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 1d9b68c262..5a7006386c 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -11,6 +11,7 @@ use std::sync::{Arc, Mutex, MutexGuard}; use std::{env, fmt}; use arr_macro::arr; +use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::ToPrimitive; use once_cell::sync::Lazy; @@ -56,23 +57,28 @@ use crate::sysmodule; pub struct VirtualMachine { pub builtins: PyObjectRef, pub sys_module: PyObjectRef, - pub stdlib_inits: RefCell>, - pub ctx: PyContext, + pub ctx: Arc, pub frames: RefCell>, pub wasm_id: Option, pub exceptions: RefCell>, - pub frozen: RefCell>, - pub import_func: RefCell, + pub import_func: PyObjectRef, pub profile_func: RefCell, pub trace_func: RefCell, - pub use_tracing: RefCell, - pub signal_handlers: RefCell<[PyObjectRef; NSIG]>, - pub settings: PySettings, + pub use_tracing: Cell, pub recursion_limit: Cell, - pub codec_registry: RefCell>, + pub signal_handlers: Option>>, + pub state: Arc, pub initialized: bool, } +pub struct PyGlobalState { + pub settings: PySettings, + pub stdlib_inits: HashMap, + pub frozen: HashMap, + pub stacksize: AtomicCell, + pub thread_count: AtomicCell, +} + pub const NSIG: usize = 64; #[derive(Copy, Clone)] @@ -175,31 +181,35 @@ impl VirtualMachine { let sysmod_dict = ctx.new_dict(); let sysmod = new_module(sysmod_dict.clone()); - let stdlib_inits = RefCell::new(stdlib::get_module_inits()); - let frozen = RefCell::new(frozen::get_module_inits()); - let import_func = RefCell::new(ctx.none()); + let import_func = ctx.none(); let profile_func = RefCell::new(ctx.none()); let trace_func = RefCell::new(ctx.none()); let signal_handlers = RefCell::new(arr![ctx.none(); 64]); let initialize_parameter = settings.initialization_parameter; + let stdlib_inits = stdlib::get_module_inits(); + let frozen = frozen::get_module_inits(); + let mut vm = VirtualMachine { builtins: builtins.clone(), sys_module: sysmod.clone(), - stdlib_inits, - ctx, + ctx: Arc::new(ctx), frames: RefCell::new(vec![]), wasm_id: None, exceptions: RefCell::new(vec![]), - frozen, import_func, profile_func, trace_func, - use_tracing: RefCell::new(false), - signal_handlers, - settings, + use_tracing: Cell::new(false), recursion_limit: Cell::new(if cfg!(debug_assertions) { 256 } else { 512 }), - codec_registry: RefCell::default(), + signal_handlers: Some(Box::new(signal_handlers)), + state: Arc::new(PyGlobalState { + settings, + stdlib_inits, + frozen, + stacksize: AtomicCell::new(0), + thread_count: AtomicCell::new(0), + }), initialized: false, }; @@ -232,7 +242,7 @@ impl VirtualMachine { builtins::make_module(self, self.builtins.clone()); sysmodule::make_module(self, self.sys_module.clone(), self.builtins.clone()); - let inner_init = || -> PyResult<()> { + let mut inner_init = || -> PyResult<()> { #[cfg(not(target_arch = "wasm32"))] import::import_builtin(self, "signal")?; @@ -240,7 +250,10 @@ impl VirtualMachine { #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] { - let io = self.import("io", &[], 0)?; + // this isn't fully compatible with CPython; it imports "io" and sets + // builtins.open to io.OpenWrapper, but this is easier, since it doesn't + // require the Python stdlib to be present + let io = self.import("_io", &[], 0)?; let io_open = self.get_attribute(io.clone(), "open")?; let set_stdio = |name, fd, mode: &str| { let stdio = self.invoke( @@ -259,20 +272,40 @@ impl VirtualMachine { set_stdio("stdout", 1, "w")?; set_stdio("stderr", 2, "w")?; - let open_wrapper = self.get_attribute(io, "OpenWrapper")?; - self.set_attr(&self.builtins, "open", open_wrapper)?; + self.set_attr(&self.builtins, "open", io_open)?; } Ok(()) }; - self.expect_pyresult(inner_init(), "initializiation failed"); + let res = inner_init(); + + self.expect_pyresult(res, "initializiation failed"); self.initialized = true; } } } + pub(crate) fn new_thread(&self) -> VirtualMachine { + VirtualMachine { + builtins: self.builtins.clone(), + sys_module: self.sys_module.clone(), + ctx: self.ctx.clone(), + frames: RefCell::new(vec![]), + wasm_id: self.wasm_id.clone(), + exceptions: RefCell::new(vec![]), + import_func: self.import_func.clone(), + profile_func: RefCell::new(self.get_none()), + trace_func: RefCell::new(self.get_none()), + use_tracing: Cell::new(false), + recursion_limit: self.recursion_limit.clone(), + signal_handlers: None, + state: self.state.clone(), + initialized: self.initialized, + } + } + pub fn run_code_obj(&self, code: PyCodeRef, scope: Scope) -> PyResult { let frame = Frame::new(code, scope).into_ref(self); self.run_frame_full(frame) @@ -814,7 +847,7 @@ impl VirtualMachine { /// Call registered trace function. fn trace_event(&self, event: TraceEvent) -> PyResult<()> { - if *self.use_tracing.borrow() { + if self.use_tracing.get() { let frame = self.get_none(); let event = self.new_str(event.to_string()); let arg = self.get_none(); @@ -824,17 +857,17 @@ impl VirtualMachine { // tracing function itself. let trace_func = self.trace_func.borrow().clone(); if !self.is_none(&trace_func) { - self.use_tracing.replace(false); + self.use_tracing.set(false); let res = self.invoke(&trace_func, args.clone()); - self.use_tracing.replace(true); + self.use_tracing.set(true); res?; } let profile_func = self.profile_func.borrow().clone(); if !self.is_none(&profile_func) { - self.use_tracing.replace(false); + self.use_tracing.set(false); let res = self.invoke(&profile_func, args); - self.use_tracing.replace(true); + self.use_tracing.set(true); res?; } } @@ -1033,7 +1066,7 @@ impl VirtualMachine { #[cfg(feature = "rustpython-compiler")] pub fn compile_opts(&self) -> CompileOpts { CompileOpts { - optimize: self.settings.optimize, + optimize: self.state.settings.optimize, } } diff --git a/wasm/lib/src/browser_module.rs b/wasm/lib/src/browser_module.rs index eefc39552c..a58d3f11ba 100644 --- a/wasm/lib/src/browser_module.rs +++ b/wasm/lib/src/browser_module.rs @@ -1,5 +1,6 @@ use futures::Future; use js_sys::Promise; +use std::sync::Arc; use wasm_bindgen::prelude::*; use wasm_bindgen::JsCast; use wasm_bindgen_futures::{future_to_promise, JsFuture}; @@ -14,6 +15,14 @@ use rustpython_vm::VirtualMachine; use crate::{convert, vm_class::weak_vm, wasm_builtins::window}; +// TODO: Fix this when threading is supported in WASM. +unsafe impl Send for PyPromise {} +unsafe impl Sync for PyPromise {} +unsafe impl Send for Document {} +unsafe impl Sync for Document {} +unsafe impl Send for Element {} +unsafe impl Sync for Element {} + enum FetchResponseFormat { Json, Text, @@ -367,11 +376,12 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { }) } -pub fn setup_browser_module(vm: &VirtualMachine) { - vm.stdlib_inits - .borrow_mut() +pub fn setup_browser_module(vm: &mut VirtualMachine) { + let state = Arc::get_mut(&mut vm.state).unwrap(); + state + .stdlib_inits .insert("_browser".to_owned(), Box::new(make_module)); - vm.frozen.borrow_mut().extend(py_compile_bytecode!( + state.frozen.extend(py_compile_bytecode!( file = "src/browser.py", module_name = "browser", )); diff --git a/wasm/lib/src/js_module.rs b/wasm/lib/src/js_module.rs index 72497c43aa..9e39e7296b 100644 --- a/wasm/lib/src/js_module.rs +++ b/wasm/lib/src/js_module.rs @@ -5,6 +5,7 @@ use rustpython_vm::obj::{objfloat::PyFloatRef, objstr::PyStringRef, objtype::PyC use rustpython_vm::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; use rustpython_vm::types::create_type; use rustpython_vm::VirtualMachine; +use std::sync::Arc; use wasm_bindgen::{prelude::*, JsCast}; #[wasm_bindgen(inline_js = " @@ -34,6 +35,10 @@ pub struct PyJsValue { } type PyJsValueRef = PyRef; +// TODO: Fix this when threading is supported in WASM. +unsafe impl Send for PyJsValue {} +unsafe impl Sync for PyJsValue {} + impl PyValue for PyJsValue { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("_js", "JsValue") @@ -253,8 +258,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { }) } -pub fn setup_js_module(vm: &VirtualMachine) { - vm.stdlib_inits - .borrow_mut() +pub fn setup_js_module(vm: &mut VirtualMachine) { + let state = Arc::get_mut(&mut vm.state).unwrap(); + state + .stdlib_inits .insert("_js".to_owned(), Box::new(make_module)); } diff --git a/wasm/lib/src/vm_class.rs b/wasm/lib/src/vm_class.rs index 199c0b5a5d..afc6bb1a8d 100644 --- a/wasm/lib/src/vm_class.rs +++ b/wasm/lib/src/vm_class.rs @@ -1,6 +1,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::{Rc, Weak}; +use std::sync::Arc; use js_sys::{Object, TypeError}; use wasm_bindgen::prelude::*; @@ -36,9 +37,9 @@ impl StoredVirtualMachine { vm.wasm_id = Some(id); let scope = vm.new_scope_with_builtins(); - js_module::setup_js_module(&vm); + js_module::setup_js_module(&mut vm); if inject_browser_module { - vm.stdlib_inits.borrow_mut().insert( + Arc::get_mut(&mut vm.state).unwrap().stdlib_inits.insert( "_window".to_owned(), Box::new(|vm| { py_module!(vm, "_window", { @@ -46,7 +47,7 @@ impl StoredVirtualMachine { }) }), ); - setup_browser_module(&vm); + setup_browser_module(&mut vm); } vm.initialize(InitParameter::InitializeInternal); @@ -288,26 +289,19 @@ impl WASMVirtualMachine { #[wasm_bindgen(js_name = injectJSModule)] pub fn inject_js_module(&self, name: String, module: Object) -> Result<(), JsValue> { self.with(|StoredVirtualMachine { ref vm, .. }| { - let mut module_items: HashMap = HashMap::new(); + let py_module = vm.new_module(&name, vm.ctx.new_dict()); for entry in convert::object_entries(&module) { let (key, value) = entry?; let key = Object::from(key).to_string(); - module_items.insert(key.into(), convert::js_to_py(vm, value)); + extend_module!(vm, py_module, { + String::from(key) => convert::js_to_py(vm, value), + }); } - let mod_name = name.clone(); - - let stdlib_init_fn = move |vm: &VirtualMachine| { - let module = vm.new_module(&name, vm.ctx.new_dict()); - for (key, value) in module_items.clone() { - vm.set_attr(&module, key, value).unwrap(); - } - module - }; - - vm.stdlib_inits - .borrow_mut() - .insert(mod_name, Box::new(stdlib_init_fn)); + let sys_modules = vm + .get_attribute(vm.sys_module.clone(), "modules") + .to_js(vm)?; + sys_modules.set_item(&name, py_module, vm).to_js(vm)?; Ok(()) })?