diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py index 2a27684324..489eec714e 100644 --- a/Lib/_weakrefset.py +++ b/Lib/_weakrefset.py @@ -80,8 +80,7 @@ def __contains__(self, item): return wr in self.data def __reduce__(self): - return (self.__class__, (list(self),), - getattr(self, '__dict__', None)) + return self.__class__, (list(self),), self.__getstate__() def add(self, item): if self._pending_removals: diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py index 28df075dcd..357d127c09 100644 --- a/Lib/logging/__init__.py +++ b/Lib/logging/__init__.py @@ -56,7 +56,7 @@ # #_startTime is used as the base when calculating the relative time of events # -_startTime = time.time() +_startTime = time.time_ns() # #raiseExceptions is used to see if exceptions during handling should be @@ -159,12 +159,9 @@ def addLevelName(level, levelName): This is used when converting levels to text during message formatting. """ - _acquireLock() - try: #unlikely to cause an exception, but you never know... + with _lock: _levelToName[level] = levelName _nameToLevel[levelName] = level - finally: - _releaseLock() if hasattr(sys, "_getframe"): currentframe = lambda: sys._getframe(1) @@ -201,7 +198,7 @@ def _is_internal_frame(frame): """Signal whether the frame is a CPython or logging module internal.""" filename = os.path.normcase(frame.f_code.co_filename) return filename == _srcfile or ( - "importlib" in filename and "_bootstrap" in filename + "importlib" in filename and "_bootstrap" in filename ) @@ -231,21 +228,27 @@ def _checkLevel(level): # _lock = threading.RLock() -def _acquireLock(): +def _prepareFork(): """ - Acquire the module-level lock for serializing access to shared data. + Prepare to fork a new child process by acquiring the module-level lock. - This should be released with _releaseLock(). + This should be used in conjunction with _afterFork(). 
""" - if _lock: + # Wrap the lock acquisition in a try-except to prevent the lock from being + # abandoned in the event of an asynchronous exception. See gh-106238. + try: _lock.acquire() + except BaseException: + _lock.release() + raise -def _releaseLock(): +def _afterFork(): """ - Release the module-level lock acquired by calling _acquireLock(). + After a new child process has been forked, release the module-level lock. + + This should be used in conjunction with _prepareFork(). """ - if _lock: - _lock.release() + _lock.release() # Prevent a held logging lock from blocking a child from logging. @@ -260,23 +263,20 @@ def _register_at_fork_reinit_lock(instance): _at_fork_reinit_lock_weakset = weakref.WeakSet() def _register_at_fork_reinit_lock(instance): - _acquireLock() - try: + with _lock: _at_fork_reinit_lock_weakset.add(instance) - finally: - _releaseLock() def _after_at_fork_child_reinit_locks(): for handler in _at_fork_reinit_lock_weakset: handler._at_fork_reinit() - # _acquireLock() was called in the parent before forking. + # _prepareFork() was called in the parent before forking. # The lock is reinitialized to unlocked state. _lock._at_fork_reinit() - os.register_at_fork(before=_acquireLock, + os.register_at_fork(before=_prepareFork, after_in_child=_after_at_fork_child_reinit_locks, - after_in_parent=_releaseLock) + after_in_parent=_afterFork) #--------------------------------------------------------------------------- @@ -300,7 +300,7 @@ def __init__(self, name, level, pathname, lineno, """ Initialize a logging record with interesting information. """ - ct = time.time() + ct = time.time_ns() self.name = name self.msg = msg # @@ -322,7 +322,7 @@ def __init__(self, name, level, pathname, lineno, # Thus, while not removing the isinstance check, it does now look # for collections.abc.Mapping rather than, as before, dict. 
if (args and len(args) == 1 and isinstance(args[0], collections.abc.Mapping) - and args[0]): + and args[0]): args = args[0] self.args = args self.levelname = getLevelName(level) @@ -339,9 +339,17 @@ def __init__(self, name, level, pathname, lineno, self.stack_info = sinfo self.lineno = lineno self.funcName = func - self.created = ct - self.msecs = int((ct - int(ct)) * 1000) + 0.0 # see gh-89047 - self.relativeCreated = (self.created - _startTime) * 1000 + self.created = ct / 1e9 # ns to float seconds + # Get the number of whole milliseconds (0-999) in the fractional part of seconds. + # Eg: 1_677_903_920_999_998_503 ns --> 999_998_503 ns--> 999 ms + # Convert to float by adding 0.0 for historical reasons. See gh-89047 + self.msecs = (ct % 1_000_000_000) // 1_000_000 + 0.0 + if self.msecs == 999.0 and int(self.created) != ct // 1_000_000_000: + # ns -> sec conversion can round up, e.g: + # 1_677_903_920_999_999_900 ns --> 1_677_903_921.0 sec + self.msecs = 0.0 + + self.relativeCreated = (ct - _startTime) / 1e6 if logThreads: self.thread = threading.get_ident() self.threadName = threading.current_thread().name @@ -378,7 +386,7 @@ def __init__(self, name, level, pathname, lineno, def __repr__(self): return ''%(self.name, self.levelno, - self.pathname, self.lineno, self.msg) + self.pathname, self.lineno, self.msg) def getMessage(self): """ @@ -572,7 +580,7 @@ class Formatter(object): %(lineno)d Source line number where the logging call was issued (if available) %(funcName)s Function name - %(created)f Time when the LogRecord was created (time.time() + %(created)f Time when the LogRecord was created (time.time_ns() / 1e9 return value) %(asctime)s Textual time when the LogRecord was created %(msecs)d Millisecond portion of the creation time @@ -583,6 +591,7 @@ class Formatter(object): %(threadName)s Thread name (if available) %(taskName)s Task name (if available) %(process)d Process ID (if available) + %(processName)s Process name (if available) %(message)s The result of 
record.getMessage(), computed just as the record is emitted """ @@ -608,7 +617,7 @@ def __init__(self, fmt=None, datefmt=None, style='%', validate=True, *, """ if style not in _STYLES: raise ValueError('Style must be one of: %s' % ','.join( - _STYLES.keys())) + _STYLES.keys())) self._style = _STYLES[style][0](fmt, defaults=defaults) if validate: self._style.validate() @@ -658,7 +667,7 @@ def formatException(self, ei): # See issues #9427, #1553375. Commented out for now. #if getattr(self, 'fullstack', False): # traceback.print_stack(tb.tb_frame.f_back, file=sio) - traceback.print_exception(ei[0], ei[1], tb, None, sio) + traceback.print_exception(ei[0], ei[1], tb, limit=None, file=sio) s = sio.getvalue() sio.close() if s[-1:] == "\n": @@ -879,25 +888,20 @@ def _removeHandlerRef(wr): # set to None. It can also be called from another thread. So we need to # pre-emptively grab the necessary globals and check if they're None, # to prevent race conditions and failures during interpreter shutdown. - acquire, release, handlers = _acquireLock, _releaseLock, _handlerList - if acquire and release and handlers: - acquire() - try: - handlers.remove(wr) - except ValueError: - pass - finally: - release() + handlers, lock = _handlerList, _lock + if lock and handlers: + with lock: + try: + handlers.remove(wr) + except ValueError: + pass def _addHandlerRef(handler): """ Add a handler to the internal cleanup list using a weak reference. """ - _acquireLock() - try: + with _lock: _handlerList.append(weakref.ref(handler, _removeHandlerRef)) - finally: - _releaseLock() def getHandlerByName(name): @@ -912,8 +916,7 @@ def getHandlerNames(): """ Return all known handler names as an immutable set. 
""" - result = set(_handlers.keys()) - return frozenset(result) + return frozenset(_handlers) class Handler(Filterer): @@ -943,15 +946,12 @@ def get_name(self): return self._name def set_name(self, name): - _acquireLock() - try: + with _lock: if self._name in _handlers: del _handlers[self._name] self._name = name if name: _handlers[name] = self - finally: - _releaseLock() name = property(get_name, set_name) @@ -1023,11 +1023,8 @@ def handle(self, record): if isinstance(rv, LogRecord): record = rv if rv: - self.acquire() - try: + with self.lock: self.emit(record) - finally: - self.release() return rv def setFormatter(self, fmt): @@ -1055,13 +1052,10 @@ def close(self): methods. """ #get the module data lock, as we're updating a shared structure. - _acquireLock() - try: #unlikely to raise an exception, but you never know... + with _lock: self._closed = True if self._name and self._name in _handlers: del _handlers[self._name] - finally: - _releaseLock() def handleError(self, record): """ @@ -1076,14 +1070,14 @@ def handleError(self, record): The record which was being processed is passed in to this method. """ if raiseExceptions and sys.stderr: # see issue 13807 - t, v, tb = sys.exc_info() + exc = sys.exception() try: sys.stderr.write('--- Logging error ---\n') - traceback.print_exception(t, v, tb, None, sys.stderr) + traceback.print_exception(exc, limit=None, file=sys.stderr) sys.stderr.write('Call stack:\n') # Walk the stack frame up until we're out of logging, # so as to print the calling context. 
- frame = tb.tb_frame + frame = exc.__traceback__.tb_frame while (frame and os.path.dirname(frame.f_code.co_filename) == __path__[0]): frame = frame.f_back @@ -1092,7 +1086,7 @@ def handleError(self, record): else: # couldn't find the right stack frame, for some reason sys.stderr.write('Logged from file %s, line %s\n' % ( - record.filename, record.lineno)) + record.filename, record.lineno)) # Issue 18671: output logging message and arguments try: sys.stderr.write('Message: %r\n' @@ -1104,11 +1098,11 @@ def handleError(self, record): sys.stderr.write('Unable to print the message and arguments' ' - possible formatting error.\nUse the' ' traceback above to help find the error.\n' - ) + ) except OSError: #pragma: no cover pass # see issue 5971 finally: - del t, v, tb + del exc def __repr__(self): level = getLevelName(self.level) @@ -1138,12 +1132,9 @@ def flush(self): """ Flushes the stream. """ - self.acquire() - try: + with self.lock: if self.stream and hasattr(self.stream, "flush"): self.stream.flush() - finally: - self.release() def emit(self, record): """ @@ -1179,12 +1170,9 @@ def setStream(self, stream): result = None else: result = self.stream - self.acquire() - try: + with self.lock: self.flush() self.stream = stream - finally: - self.release() return result def __repr__(self): @@ -1234,8 +1222,7 @@ def close(self): """ Closes the stream. 
""" - self.acquire() - try: + with self.lock: try: if self.stream: try: @@ -1251,8 +1238,6 @@ def close(self): # Also see Issue #42378: we also rely on # self._closed being set to True there StreamHandler.close(self) - finally: - self.release() def _open(self): """ @@ -1388,8 +1373,7 @@ def getLogger(self, name): rv = None if not isinstance(name, str): raise TypeError('A logger name must be a string') - _acquireLock() - try: + with _lock: if name in self.loggerDict: rv = self.loggerDict[name] if isinstance(rv, PlaceHolder): @@ -1404,8 +1388,6 @@ def getLogger(self, name): rv.manager = self self.loggerDict[name] = rv self._fixupParents(rv) - finally: - _releaseLock() return rv def setLoggerClass(self, klass): @@ -1468,12 +1450,11 @@ def _clear_cache(self): Called when level changes are made """ - _acquireLock() - for logger in self.loggerDict.values(): - if isinstance(logger, Logger): - logger._cache.clear() - self.root._cache.clear() - _releaseLock() + with _lock: + for logger in self.loggerDict.values(): + if isinstance(logger, Logger): + logger._cache.clear() + self.root._cache.clear() #--------------------------------------------------------------------------- # Logger classes and functions @@ -1494,6 +1475,8 @@ class Logger(Filterer): level, and "input.csv", "input.xls" and "input.gnu" for the sub-levels. There is no arbitrary limit to the depth of nesting. """ + _tls = threading.local() + def __init__(self, name, level=NOTSET): """ Initialize the logger with a name and an optional level. @@ -1552,7 +1535,7 @@ def warning(self, msg, *args, **kwargs): def warn(self, msg, *args, **kwargs): warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) self.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -1649,7 +1632,7 @@ def makeRecord(self, name, level, fn, lno, msg, args, exc_info, specialized LogRecords. 
""" rv = _logRecordFactory(name, level, fn, lno, msg, args, exc_info, func, - sinfo) + sinfo) if extra is not None: for key in extra: if (key in ["message", "asctime"]) or (key in rv.__dict__): @@ -1690,36 +1673,35 @@ def handle(self, record): This method is used for unpickled records received from a socket, as well as those created locally. Logger-level filtering is applied. """ - if self.disabled: - return - maybe_record = self.filter(record) - if not maybe_record: + if self._is_disabled(): return - if isinstance(maybe_record, LogRecord): - record = maybe_record - self.callHandlers(record) + + self._tls.in_progress = True + try: + maybe_record = self.filter(record) + if not maybe_record: + return + if isinstance(maybe_record, LogRecord): + record = maybe_record + self.callHandlers(record) + finally: + self._tls.in_progress = False def addHandler(self, hdlr): """ Add the specified handler to this logger. """ - _acquireLock() - try: + with _lock: if not (hdlr in self.handlers): self.handlers.append(hdlr) - finally: - _releaseLock() def removeHandler(self, hdlr): """ Remove the specified handler from this logger. """ - _acquireLock() - try: + with _lock: if hdlr in self.handlers: self.handlers.remove(hdlr) - finally: - _releaseLock() def hasHandlers(self): """ @@ -1791,22 +1773,19 @@ def isEnabledFor(self, level): """ Is this logger enabled for level 'level'? 
""" - if self.disabled: + if self._is_disabled(): return False try: return self._cache[level] except KeyError: - _acquireLock() - try: + with _lock: if self.manager.disable >= level: is_enabled = self._cache[level] = False else: is_enabled = self._cache[level] = ( - level >= self.getEffectiveLevel() + level >= self.getEffectiveLevel() ) - finally: - _releaseLock() return is_enabled def getChild(self, suffix): @@ -1836,16 +1815,18 @@ def _hierlevel(logger): return 1 + logger.name.count('.') d = self.manager.loggerDict - _acquireLock() - try: + with _lock: # exclude PlaceHolders - the last check is to ensure that lower-level # descendants aren't returned - if there are placeholders, a logger's # parent field might point to a grandparent or ancestor thereof. return set(item for item in d.values() if isinstance(item, Logger) and item.parent is self and _hierlevel(item) == 1 + _hierlevel(item.parent)) - finally: - _releaseLock() + + def _is_disabled(self): + # We need to use getattr as it will only be set the first time a log + # message is recorded on any given thread + return self.disabled or getattr(self._tls, 'in_progress', False) def __repr__(self): level = getLevelName(self.getEffectiveLevel()) @@ -1881,7 +1862,7 @@ class LoggerAdapter(object): information in logging output. """ - def __init__(self, logger, extra=None): + def __init__(self, logger, extra=None, merge_extra=False): """ Initialize the adapter with a logger and a dict-like object which provides contextual information. This constructor signature allows @@ -1891,9 +1872,20 @@ def __init__(self, logger, extra=None): following example: adapter = LoggerAdapter(someLogger, dict(p1=v1, p2="v2")) + + By default, LoggerAdapter objects will drop the "extra" argument + passed on the individual log calls to use its own instead. + + Initializing it with merge_extra=True will instead merge both + maps when logging, the individual call extra taking precedence + over the LoggerAdapter instance extra + + .. 
versionchanged:: 3.13 + The *merge_extra* argument was added. """ self.logger = logger self.extra = extra + self.merge_extra = merge_extra def process(self, msg, kwargs): """ @@ -1905,7 +1897,10 @@ def process(self, msg, kwargs): Normally, you'll only need to override this one method in a LoggerAdapter subclass for your specific needs. """ - kwargs["extra"] = self.extra + if self.merge_extra and "extra" in kwargs: + kwargs["extra"] = {**self.extra, **kwargs["extra"]} + else: + kwargs["extra"] = self.extra return msg, kwargs # @@ -1931,7 +1926,7 @@ def warning(self, msg, *args, **kwargs): def warn(self, msg, *args, **kwargs): warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) self.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -2088,8 +2083,7 @@ def basicConfig(**kwargs): """ # Add thread safety in case someone mistakenly calls # basicConfig() from multiple threads - _acquireLock() - try: + with _lock: force = kwargs.pop('force', False) encoding = kwargs.pop('encoding', None) errors = kwargs.pop('errors', 'backslashreplace') @@ -2125,7 +2119,7 @@ def basicConfig(**kwargs): style = kwargs.pop("style", '%') if style not in _STYLES: raise ValueError('Style must be one of: %s' % ','.join( - _STYLES.keys())) + _STYLES.keys())) fs = kwargs.pop("format", _STYLES[style][1]) fmt = Formatter(fs, dfs, style) for h in handlers: @@ -2138,8 +2132,6 @@ def basicConfig(**kwargs): if kwargs: keys = ', '.join(kwargs.keys()) raise ValueError('Unrecognised argument(s): %s' % keys) - finally: - _releaseLock() #--------------------------------------------------------------------------- # Utility functions at module level. 
@@ -2202,7 +2194,7 @@ def warning(msg, *args, **kwargs): def warn(msg, *args, **kwargs): warnings.warn("The 'warn' function is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) warning(msg, *args, **kwargs) def info(msg, *args, **kwargs): diff --git a/Lib/logging/config.py b/Lib/logging/config.py index ef04a35168..190b4f9225 100644 --- a/Lib/logging/config.py +++ b/Lib/logging/config.py @@ -83,15 +83,12 @@ def fileConfig(fname, defaults=None, disable_existing_loggers=True, encoding=Non formatters = _create_formatters(cp) # critical section - logging._acquireLock() - try: + with logging._lock: _clearExistingHandlers() # Handlers add themselves to logging._handlers handlers = _install_handlers(cp, formatters) _install_loggers(cp, handlers, disable_existing_loggers) - finally: - logging._releaseLock() def _resolve(name): @@ -314,7 +311,7 @@ def convert_with_key(self, key, value, replace=True): if replace: self[key] = result if type(result) in (ConvertingDict, ConvertingList, - ConvertingTuple): + ConvertingTuple): result.parent = self result.key = key return result @@ -323,7 +320,7 @@ def convert(self, value): result = self.configurator.convert(value) if value is not result: if type(result) in (ConvertingDict, ConvertingList, - ConvertingTuple): + ConvertingTuple): result.parent = self return result @@ -378,7 +375,7 @@ class BaseConfigurator(object): WORD_PATTERN = re.compile(r'^\s*(\w+)\s*') DOT_PATTERN = re.compile(r'^\.\s*(\w+)\s*') - INDEX_PATTERN = re.compile(r'^\[\s*(\w+)\s*\]\s*') + INDEX_PATTERN = re.compile(r'^\[([^\[\]]*)\]\s*') DIGIT_PATTERN = re.compile(r'^\d+$') value_converters = { @@ -464,8 +461,8 @@ def convert(self, value): elif not isinstance(value, ConvertingList) and isinstance(value, list): value = ConvertingList(value) value.configurator = self - elif not isinstance(value, ConvertingTuple) and \ - isinstance(value, tuple) and not hasattr(value, '_fields'): + elif not 
isinstance(value, ConvertingTuple) and\ + isinstance(value, tuple) and not hasattr(value, '_fields'): value = ConvertingTuple(value) value.configurator = self elif isinstance(value, str): # str for py3k @@ -543,8 +540,7 @@ def configure(self): raise ValueError("Unsupported version: %s" % config['version']) incremental = config.pop('incremental', False) EMPTY_DICT = {} - logging._acquireLock() - try: + with logging._lock: if incremental: handlers = config.get('handlers', EMPTY_DICT) for name in handlers: @@ -585,7 +581,7 @@ def configure(self): for name in formatters: try: formatters[name] = self.configure_formatter( - formatters[name]) + formatters[name]) except Exception as e: raise ValueError('Unable to configure ' 'formatter %r' % name) from e @@ -688,8 +684,6 @@ def configure(self): except Exception as e: raise ValueError('Unable to configure root ' 'logger') from e - finally: - logging._releaseLock() def configure_formatter(self, config): """Configure a formatter from a dictionary.""" @@ -700,10 +694,9 @@ def configure_formatter(self, config): except TypeError as te: if "'format'" not in str(te): raise - #Name of parameter changed from fmt to format. - #Retry with old name. - #This is so that code can be used with older Python versions - #(e.g. by Django) + # logging.Formatter and its subclasses expect the `fmt` + # parameter instead of `format`. Retry passing configuration + # with `fmt`. 
config['fmt'] = config.pop('format') config['()'] = factory result = self.configure_custom(config) @@ -812,7 +805,7 @@ def configure_handler(self, config): elif issubclass(klass, logging.handlers.QueueHandler): # Another special case for handler which refers to other handlers # if 'handlers' not in config: - # raise ValueError('No handlers specified for a QueueHandler') + # raise ValueError('No handlers specified for a QueueHandler') if 'queue' in config: qspec = config['queue'] @@ -836,8 +829,8 @@ def configure_handler(self, config): else: if isinstance(lspec, str): listener = self.resolve(lspec) - if isinstance(listener, type) and \ - not issubclass(listener, logging.handlers.QueueListener): + if isinstance(listener, type) and\ + not issubclass(listener, logging.handlers.QueueListener): raise TypeError('Invalid listener specifier %r' % lspec) elif isinstance(lspec, dict): if '()' not in lspec: @@ -861,11 +854,11 @@ def configure_handler(self, config): except Exception as e: raise ValueError('Unable to set required handler %r' % hn) from e config['handlers'] = hlist - elif issubclass(klass, logging.handlers.SMTPHandler) and \ - 'mailhost' in config: + elif issubclass(klass, logging.handlers.SMTPHandler) and\ + 'mailhost' in config: config['mailhost'] = self.as_tuple(config['mailhost']) - elif issubclass(klass, logging.handlers.SysLogHandler) and \ - 'address' in config: + elif issubclass(klass, logging.handlers.SysLogHandler) and\ + 'address' in config: config['address'] = self.as_tuple(config['address']) if issubclass(klass, logging.handlers.QueueHandler): factory = functools.partial(self._configure_queue_handler, klass) @@ -1018,9 +1011,8 @@ class ConfigSocketReceiver(ThreadingTCPServer): def __init__(self, host='localhost', port=DEFAULT_LOGGING_CONFIG_PORT, handler=None, ready=None, verify=None): ThreadingTCPServer.__init__(self, (host, port), handler) - logging._acquireLock() - self.abort = 0 - logging._releaseLock() + with logging._lock: + self.abort = 0 
self.timeout = 1 self.ready = ready self.verify = verify @@ -1034,9 +1026,8 @@ def serve_until_stopped(self): self.timeout) if rd: self.handle_request() - logging._acquireLock() - abort = self.abort - logging._releaseLock() + with logging._lock: + abort = self.abort self.server_close() class Server(threading.Thread): @@ -1057,9 +1048,8 @@ def run(self): self.port = server.server_address[1] self.ready.set() global _listener - logging._acquireLock() - _listener = server - logging._releaseLock() + with logging._lock: + _listener = server server.serve_until_stopped() return Server(ConfigSocketReceiver, ConfigStreamHandler, port, verify) @@ -1069,10 +1059,7 @@ def stopListening(): Stop the listening server which was created with a call to listen(). """ global _listener - logging._acquireLock() - try: + with logging._lock: if _listener: _listener.abort = 1 _listener = None - finally: - logging._releaseLock() diff --git a/Lib/logging/handlers.py b/Lib/logging/handlers.py index bf42ea1103..d3ea06c731 100644 --- a/Lib/logging/handlers.py +++ b/Lib/logging/handlers.py @@ -23,11 +23,17 @@ To use, simply 'import logging.handlers' and log away! """ -import io, logging, socket, os, pickle, struct, time, re -from stat import ST_DEV, ST_INO, ST_MTIME +import copy +import io +import logging +import os +import pickle import queue +import re +import socket +import struct import threading -import copy +import time # # Some constants... 
@@ -272,7 +278,7 @@ def __init__(self, filename, when='h', interval=1, backupCount=0, # path object (see Issue #27493), but self.baseFilename will be a string filename = self.baseFilename if os.path.exists(filename): - t = os.stat(filename)[ST_MTIME] + t = int(os.stat(filename).st_mtime) else: t = int(time.time()) self.rolloverAt = self.computeRollover(t) @@ -304,10 +310,10 @@ def computeRollover(self, currentTime): rotate_ts = _MIDNIGHT else: rotate_ts = ((self.atTime.hour * 60 + self.atTime.minute)*60 + - self.atTime.second) + self.atTime.second) r = rotate_ts - ((currentHour * 60 + currentMinute) * 60 + - currentSecond) + currentSecond) if r <= 0: # Rotate time is before the current time (for example when # self.rotateAt is 13:45 and it now 14:15), rotation is @@ -465,8 +471,7 @@ class WatchedFileHandler(logging.FileHandler): This handler is not appropriate for use under Windows, because under Windows open files cannot be moved or renamed - logging opens the files with exclusive locks - and so there is no need - for such a handler. Furthermore, ST_INO is not supported under - Windows; stat always returns zero for this value. + for such a handler. This handler is based on a suggestion and patch by Chad J. Schroeder. @@ -482,9 +487,11 @@ def __init__(self, filename, mode='a', encoding=None, delay=False, self._statstream() def _statstream(self): - if self.stream: - sres = os.fstat(self.stream.fileno()) - self.dev, self.ino = sres[ST_DEV], sres[ST_INO] + if self.stream is None: + return + sres = os.fstat(self.stream.fileno()) + self.dev = sres.st_dev + self.ino = sres.st_ino def reopenIfNeeded(self): """ @@ -494,6 +501,9 @@ def reopenIfNeeded(self): has, close the old stream and reopen the file to get the current stream. """ + if self.stream is None: + return + # Reduce the chance of race conditions by stat'ing by path only # once and then fstat'ing our new fd if we opened a new log stream. 
# See issue #14632: Thanks to John Mulligan for the problem report @@ -501,18 +511,23 @@ def reopenIfNeeded(self): try: # stat the file by path, checking for existence sres = os.stat(self.baseFilename) + + # compare file system stat with that of our stream file handle + reopen = (sres.st_dev != self.dev or sres.st_ino != self.ino) except FileNotFoundError: - sres = None - # compare file system stat with that of our stream file handle - if not sres or sres[ST_DEV] != self.dev or sres[ST_INO] != self.ino: - if self.stream is not None: - # we have an open file handle, clean it up - self.stream.flush() - self.stream.close() - self.stream = None # See Issue #21742: _open () might fail. - # open a new file handle and get new stat info from that fd - self.stream = self._open() - self._statstream() + reopen = True + + if not reopen: + return + + # we have an open file handle, clean it up + self.stream.flush() + self.stream.close() + self.stream = None # See Issue #21742: _open () might fail. + + # open a new file handle and get new stat info from that fd + self.stream = self._open() + self._statstream() def emit(self, record): """ @@ -682,15 +697,12 @@ def close(self): """ Closes the socket. """ - self.acquire() - try: + with self.lock: sock = self.sock if sock: self.sock = None sock.close() logging.Handler.close(self) - finally: - self.release() class DatagramHandler(SocketHandler): """ @@ -803,7 +815,7 @@ class SysLogHandler(logging.Handler): "panic": LOG_EMERG, # DEPRECATED "warn": LOG_WARNING, # DEPRECATED "warning": LOG_WARNING, - } + } facility_names = { "auth": LOG_AUTH, @@ -830,7 +842,7 @@ class SysLogHandler(logging.Handler): "local5": LOG_LOCAL5, "local6": LOG_LOCAL6, "local7": LOG_LOCAL7, - } + } # Originally added to work around GH-43683. Unnecessary since GH-50043 but kept # for backwards compatibility. @@ -950,15 +962,12 @@ def close(self): """ Closes the socket. 
""" - self.acquire() - try: + with self.lock: sock = self.socket if sock: self.socket = None sock.close() logging.Handler.close(self) - finally: - self.release() def mapPriority(self, levelName): """ @@ -1031,7 +1040,8 @@ def __init__(self, mailhost, fromaddr, toaddrs, subject, only be used when authentication credentials are supplied. The tuple will be either an empty tuple, or a single-value tuple with the name of a keyfile, or a 2-value tuple with the names of the keyfile and - certificate file. (This tuple is passed to the `starttls` method). + certificate file. (This tuple is passed to the + `ssl.SSLContext.load_cert_chain` method). A timeout in seconds can be specified for the SMTP connection (the default is one second). """ @@ -1084,8 +1094,23 @@ def emit(self, record): msg.set_content(self.format(record)) if self.username: if self.secure is not None: + import ssl + + try: + keyfile = self.secure[0] + except IndexError: + keyfile = None + + try: + certfile = self.secure[1] + except IndexError: + certfile = None + + context = ssl._create_stdlib_context( + certfile=certfile, keyfile=keyfile + ) smtp.ehlo() - smtp.starttls(*self.secure) + smtp.starttls(context=context) smtp.ehlo() smtp.login(self.username, self.password) smtp.send_message(msg) @@ -1132,10 +1157,10 @@ def __init__(self, appname, dllname=None, logtype="Application"): logging.WARNING : win32evtlog.EVENTLOG_WARNING_TYPE, logging.ERROR : win32evtlog.EVENTLOG_ERROR_TYPE, logging.CRITICAL: win32evtlog.EVENTLOG_ERROR_TYPE, - } + } except ImportError: - print("The Python Win32 extensions for NT (service, event " \ - "logging) appear not to be available.") + print("The Python Win32 extensions for NT (service, event "\ + "logging) appear not to be available.") self._welu = None def getMessageID(self, record): @@ -1330,11 +1355,8 @@ def flush(self): This version just zaps the buffer to empty. 
""" - self.acquire() - try: + with self.lock: self.buffer.clear() - finally: - self.release() def close(self): """ @@ -1378,17 +1400,14 @@ def shouldFlush(self, record): Check for buffer full or a record at the flushLevel or higher. """ return (len(self.buffer) >= self.capacity) or \ - (record.levelno >= self.flushLevel) + (record.levelno >= self.flushLevel) def setTarget(self, target): """ Set the target handler for this handler. """ - self.acquire() - try: + with self.lock: self.target = target - finally: - self.release() def flush(self): """ @@ -1398,14 +1417,11 @@ def flush(self): The record buffer is only cleared if a target has been set. """ - self.acquire() - try: + with self.lock: if self.target: for record in self.buffer: self.target.handle(record) self.buffer.clear() - finally: - self.release() def close(self): """ @@ -1416,12 +1432,9 @@ def close(self): if self.flushOnClose: self.flush() finally: - self.acquire() - try: + with self.lock: self.target = None BufferingHandler.close(self) - finally: - self.release() class QueueHandler(logging.Handler): @@ -1532,6 +1545,9 @@ def start(self): This starts up a background thread to monitor the queue for LogRecords to process. """ + if self._thread is not None: + raise RuntimeError("Listener already started") + self._thread = t = threading.Thread(target=self._monitor) t.daemon = True t.start() @@ -1603,6 +1619,7 @@ def stop(self): Note that if you don't call this before your application exits, there may be some records still left on the queue, which won't be processed. 
""" - self.enqueue_sentinel() - self._thread.join() - self._thread = None + if self._thread: # see gh-114706 - allow calling this more than once + self.enqueue_sentinel() + self._thread.join() + self._thread = None diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py index d0582e3cd5..8caddd204d 100644 --- a/Lib/multiprocessing/connection.py +++ b/Lib/multiprocessing/connection.py @@ -19,7 +19,6 @@ import tempfile import itertools -import _multiprocessing from . import util @@ -28,6 +27,7 @@ _ForkingPickler = reduction.ForkingPickler try: + import _multiprocessing import _winapi from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE except ImportError: @@ -846,7 +846,7 @@ def PipeClient(address): _LEGACY_LENGTHS = (_MD5ONLY_MESSAGE_LENGTH, _MD5_DIGEST_LEN) -def _get_digest_name_and_payload(message: bytes) -> (str, bytes): +def _get_digest_name_and_payload(message): # type: (bytes) -> tuple[str, bytes] """Returns a digest name and the payload for a response hash. If a legacy protocol is detected based on the message length @@ -956,7 +956,7 @@ def answer_challenge(connection, authkey: bytes): f'Protocol error, expected challenge: {message=}') message = message[len(_CHALLENGE):] if len(message) < _MD5ONLY_MESSAGE_LENGTH: - raise AuthenticationError('challenge too short: {len(message)} bytes') + raise AuthenticationError(f'challenge too short: {len(message)} bytes') digest = _create_response(authkey, message) connection.send_bytes(digest) response = connection.recv_bytes(256) # reject large message @@ -1012,8 +1012,20 @@ def _exhaustive_wait(handles, timeout): # returning the first signalled might create starvation issues.) L = list(handles) ready = [] + # Windows limits WaitForMultipleObjects at 64 handles, and we use a + # few for synchronisation, so we switch to batched waits at 60. 
+ if len(L) > 60: + try: + res = _winapi.BatchedWaitForMultipleObjects(L, False, timeout) + except TimeoutError: + return [] + ready.extend(L[i] for i in res) + if res: + L = [h for i, h in enumerate(L) if i > res[0] & i not in res] + timeout = 0 while L: - res = _winapi.WaitForMultipleObjects(L, False, timeout) + short_L = L[:60] if len(L) > 60 else L + res = _winapi.WaitForMultipleObjects(short_L, False, timeout) if res == WAIT_TIMEOUT: break elif WAIT_OBJECT_0 <= res < WAIT_OBJECT_0 + len(L): diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py index de8a264829..f395e8b04d 100644 --- a/Lib/multiprocessing/context.py +++ b/Lib/multiprocessing/context.py @@ -145,7 +145,7 @@ def freeze_support(self): '''Check whether this is a fake forked process in a frozen executable. If so then run code specified by commandline and exit. ''' - if sys.platform == 'win32' and getattr(sys, 'frozen', False): + if self.get_start_method() == 'spawn' and getattr(sys, 'frozen', False): from .spawn import freeze_support freeze_support() diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 4642707dae..bff7fb91d9 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -1,3 +1,4 @@ +import atexit import errno import os import selectors @@ -167,6 +168,8 @@ def ensure_running(self): def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): '''Run forkserver.''' if preload: + if sys_path is not None: + sys.path[:] = sys_path if '__main__' in preload and main_path is not None: process.current_process()._inheriting = True try: @@ -271,6 +274,8 @@ def sigchld_handler(*_unused): selector.close() unused_fds = [alive_r, child_w, sig_r, sig_w] unused_fds.extend(pid_to_fd.values()) + atexit._clear() + atexit.register(util._exit_function) code = _serve_one(child_r, fds, unused_fds, old_handlers) @@ -278,6 +283,7 @@ def sigchld_handler(*_unused): sys.excepthook(*sys.exc_info()) sys.stderr.flush() 
finally: + atexit._run_exitfuncs() os._exit(code) else: # Send pid to client process diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py index 75d9c18c20..ef791c2751 100644 --- a/Lib/multiprocessing/managers.py +++ b/Lib/multiprocessing/managers.py @@ -90,7 +90,10 @@ def dispatch(c, id, methodname, args=(), kwds={}): kind, result = c.recv() if kind == '#RETURN': return result - raise convert_to_error(kind, result) + try: + raise convert_to_error(kind, result) + finally: + del result # break reference cycle def convert_to_error(kind, result): if kind == '#ERROR': @@ -755,22 +758,29 @@ class BaseProxy(object): _address_to_local = {} _mutex = util.ForkAwareThreadLock() + # Each instance gets a `_serial` number. Unlike `id(...)`, this number + # is never reused. + _next_serial = 1 + def __init__(self, token, serializer, manager=None, authkey=None, exposed=None, incref=True, manager_owned=False): with BaseProxy._mutex: - tls_idset = BaseProxy._address_to_local.get(token.address, None) - if tls_idset is None: - tls_idset = util.ForkAwareLocal(), ProcessLocalSet() - BaseProxy._address_to_local[token.address] = tls_idset + tls_serials = BaseProxy._address_to_local.get(token.address, None) + if tls_serials is None: + tls_serials = util.ForkAwareLocal(), ProcessLocalSet() + BaseProxy._address_to_local[token.address] = tls_serials + + self._serial = BaseProxy._next_serial + BaseProxy._next_serial += 1 # self._tls is used to record the connection used by this # thread to communicate with the manager at token.address - self._tls = tls_idset[0] + self._tls = tls_serials[0] - # self._idset is used to record the identities of all shared - # objects for which the current process owns references and + # self._all_serials is a set used to record the identities of all + # shared objects for which the current process owns references and # which are in the manager at token.address - self._idset = tls_idset[1] + self._all_serials = tls_serials[1] self._token = 
token self._id = self._token.id @@ -833,7 +843,10 @@ def _callmethod(self, methodname, args=(), kwds={}): conn = self._Client(token.address, authkey=self._authkey) dispatch(conn, None, 'decref', (token.id,)) return proxy - raise convert_to_error(kind, result) + try: + raise convert_to_error(kind, result) + finally: + del result # break reference cycle def _getvalue(self): ''' @@ -850,20 +863,20 @@ def _incref(self): dispatch(conn, None, 'incref', (self._id,)) util.debug('INCREF %r', self._token.id) - self._idset.add(self._id) + self._all_serials.add(self._serial) state = self._manager and self._manager._state self._close = util.Finalize( self, BaseProxy._decref, - args=(self._token, self._authkey, state, - self._tls, self._idset, self._Client), + args=(self._token, self._serial, self._authkey, state, + self._tls, self._all_serials, self._Client), exitpriority=10 ) @staticmethod - def _decref(token, authkey, state, tls, idset, _Client): - idset.discard(token.id) + def _decref(token, serial, authkey, state, tls, idset, _Client): + idset.discard(serial) # check whether manager is still alive if state is None or state.value == State.STARTED: @@ -1159,15 +1172,19 @@ def __imul__(self, value): self._callmethod('__imul__', (value,)) return self + __class_getitem__ = classmethod(types.GenericAlias) -DictProxy = MakeProxyType('DictProxy', ( + +_BaseDictProxy = MakeProxyType('DictProxy', ( '__contains__', '__delitem__', '__getitem__', '__iter__', '__len__', '__setitem__', 'clear', 'copy', 'get', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values' )) -DictProxy._method_to_typeid_ = { +_BaseDictProxy._method_to_typeid_ = { '__iter__': 'Iterator', } +class DictProxy(_BaseDictProxy): + __class_getitem__ = classmethod(types.GenericAlias) ArrayProxy = MakeProxyType('ArrayProxy', ( diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 4f5d88cb97..f979890170 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -200,7 
+200,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), self._initargs = initargs if processes is None: - processes = os.cpu_count() or 1 + processes = os.process_cpu_count() or 1 if processes < 1: raise ValueError("Number of processes must be at least 1") if maxtasksperchild is not None: diff --git a/Lib/multiprocessing/popen_fork.py b/Lib/multiprocessing/popen_fork.py index 625981cf47..a57ef6bdad 100644 --- a/Lib/multiprocessing/popen_fork.py +++ b/Lib/multiprocessing/popen_fork.py @@ -1,3 +1,4 @@ +import atexit import os import signal @@ -66,10 +67,13 @@ def _launch(self, process_obj): self.pid = os.fork() if self.pid == 0: try: + atexit._clear() + atexit.register(util._exit_function) os.close(parent_r) os.close(parent_w) code = process_obj._bootstrap(parent_sentinel=child_r) finally: + atexit._run_exitfuncs() os._exit(code) else: os.close(child_w) diff --git a/Lib/multiprocessing/popen_spawn_win32.py b/Lib/multiprocessing/popen_spawn_win32.py index 49d4c7eea2..62fb0ddbf9 100644 --- a/Lib/multiprocessing/popen_spawn_win32.py +++ b/Lib/multiprocessing/popen_spawn_win32.py @@ -3,6 +3,7 @@ import signal import sys import _winapi +from subprocess import STARTUPINFO, STARTF_FORCEOFFFEEDBACK from .context import reduction, get_spawning_popen, set_spawning_popen from . 
import spawn @@ -74,7 +75,8 @@ def __init__(self, process_obj): try: hp, ht, pid, tid = _winapi.CreateProcess( python_exe, cmd, - None, None, False, 0, env, None, None) + None, None, False, 0, env, None, + STARTUPINFO(dwFlags=STARTF_FORCEOFFFEEDBACK)) _winapi.CloseHandle(ht) except: _winapi.CloseHandle(rhandle) diff --git a/Lib/multiprocessing/process.py b/Lib/multiprocessing/process.py index 271ba3fd32..b45f7df476 100644 --- a/Lib/multiprocessing/process.py +++ b/Lib/multiprocessing/process.py @@ -310,11 +310,8 @@ def _bootstrap(self, parent_sentinel=None): # _run_after_forkers() is executed del old_process util.info('child process calling self.run()') - try: - self.run() - exitcode = 0 - finally: - util._exit_function() + self.run() + exitcode = 0 except SystemExit as e: if e.code is None: exitcode = 0 diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index 852ae87b27..925f043900 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -20,8 +20,6 @@ from queue import Empty, Full -import _multiprocessing - from . import connection from . 
import context _ForkingPickler = context.reduction.ForkingPickler diff --git a/Lib/multiprocessing/resource_tracker.py b/Lib/multiprocessing/resource_tracker.py index 79e96ecf32..05633ac21a 100644 --- a/Lib/multiprocessing/resource_tracker.py +++ b/Lib/multiprocessing/resource_tracker.py @@ -29,8 +29,12 @@ _HAVE_SIGMASK = hasattr(signal, 'pthread_sigmask') _IGNORED_SIGNALS = (signal.SIGINT, signal.SIGTERM) +def cleanup_noop(name): + raise RuntimeError('noop should never be registered or cleaned up') + _CLEANUP_FUNCS = { - 'noop': lambda: None, + 'noop': cleanup_noop, + 'dummy': lambda name: None, # Dummy resource used in tests } if os.name == 'posix': @@ -61,6 +65,7 @@ def __init__(self): self._lock = threading.RLock() self._fd = None self._pid = None + self._exitcode = None def _reentrant_call_error(self): # gh-109629: this happens if an explicit call to the ResourceTracker @@ -70,22 +75,53 @@ def _reentrant_call_error(self): raise ReentrantCallError( "Reentrant call into the multiprocessing resource tracker") - def _stop(self): - with self._lock: - # This should not happen (_stop() isn't called by a finalizer) - # but we check for it anyway. - if self._lock._recursion_count() > 1: - return self._reentrant_call_error() - if self._fd is None: - # not running - return + def __del__(self): + # making sure child processess are cleaned before ResourceTracker + # gets destructed. 
+ # see https://github.com/python/cpython/issues/88887 + self._stop(use_blocking_lock=False) - # closing the "alive" file descriptor stops main() - os.close(self._fd) - self._fd = None + def _stop(self, use_blocking_lock=True): + if use_blocking_lock: + with self._lock: + self._stop_locked() + else: + acquired = self._lock.acquire(blocking=False) + try: + self._stop_locked() + finally: + if acquired: + self._lock.release() + + def _stop_locked( + self, + close=os.close, + waitpid=os.waitpid, + waitstatus_to_exitcode=os.waitstatus_to_exitcode, + ): + # This shouldn't happen (it might when called by a finalizer) + # so we check for it anyway. + if self._lock._recursion_count() > 1: + return self._reentrant_call_error() + if self._fd is None: + # not running + return + if self._pid is None: + return + + # closing the "alive" file descriptor stops main() + close(self._fd) + self._fd = None - os.waitpid(self._pid, 0) - self._pid = None + _, status = waitpid(self._pid, 0) + + self._pid = None + + try: + self._exitcode = waitstatus_to_exitcode(status) + except ValueError: + # os.waitstatus_to_exitcode may raise an exception for invalid values + self._exitcode = None def getfd(self): self.ensure_running() @@ -119,6 +155,7 @@ def ensure_running(self): pass self._fd = None self._pid = None + self._exitcode = None warnings.warn('resource_tracker: process died unexpectedly, ' 'relaunching. Some resources might leak.') @@ -142,13 +179,14 @@ def ensure_running(self): # that can make the child die before it registers signal handlers # for SIGINT and SIGTERM. The mask is unregistered after spawning # the child. 
+ prev_sigmask = None try: if _HAVE_SIGMASK: - signal.pthread_sigmask(signal.SIG_BLOCK, _IGNORED_SIGNALS) + prev_sigmask = signal.pthread_sigmask(signal.SIG_BLOCK, _IGNORED_SIGNALS) pid = util.spawnv_passfds(exe, args, fds_to_pass) finally: - if _HAVE_SIGMASK: - signal.pthread_sigmask(signal.SIG_UNBLOCK, _IGNORED_SIGNALS) + if prev_sigmask is not None: + signal.pthread_sigmask(signal.SIG_SETMASK, prev_sigmask) except: os.close(w) raise @@ -221,6 +259,8 @@ def main(fd): pass cache = {rtype: set() for rtype in _CLEANUP_FUNCS.keys()} + exit_code = 0 + try: # keep track of registered/unregistered resources with open(fd, 'rb') as f: @@ -242,6 +282,7 @@ def main(fd): else: raise RuntimeError('unrecognized command %r' % cmd) except Exception: + exit_code = 3 try: sys.excepthook(*sys.exc_info()) except: @@ -251,9 +292,17 @@ def main(fd): for rtype, rtype_cache in cache.items(): if rtype_cache: try: - warnings.warn('resource_tracker: There appear to be %d ' - 'leaked %s objects to clean up at shutdown' % - (len(rtype_cache), rtype)) + exit_code = 1 + if rtype == 'dummy': + # The test 'dummy' resource is expected to leak. + # We skip the warning (and *only* the warning) for it. 
+ pass + else: + warnings.warn( + f'resource_tracker: There appear to be ' + f'{len(rtype_cache)} leaked {rtype} objects to ' + f'clean up at shutdown: {rtype_cache}' + ) except Exception: pass for name in rtype_cache: @@ -264,6 +313,9 @@ def main(fd): try: _CLEANUP_FUNCS[rtype](name) except Exception as e: + exit_code = 2 warnings.warn('resource_tracker: %r: %s' % (name, e)) finally: pass + + sys.exit(exit_code) diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py index 9a1e5aa17b..67e70fdc27 100644 --- a/Lib/multiprocessing/shared_memory.py +++ b/Lib/multiprocessing/shared_memory.py @@ -71,8 +71,9 @@ class SharedMemory: _flags = os.O_RDWR _mode = 0o600 _prepend_leading_slash = True if _USE_POSIX else False + _track = True - def __init__(self, name=None, create=False, size=0): + def __init__(self, name=None, create=False, size=0, *, track=True): if not size >= 0: raise ValueError("'size' must be a positive integer") if create: @@ -82,6 +83,7 @@ def __init__(self, name=None, create=False, size=0): if name is None and not self._flags & os.O_EXCL: raise ValueError("'name' can only be None if create=True") + self._track = track if _USE_POSIX: # POSIX Shared Memory @@ -116,8 +118,8 @@ def __init__(self, name=None, create=False, size=0): except OSError: self.unlink() raise - - resource_tracker.register(self._name, "shared_memory") + if self._track: + resource_tracker.register(self._name, "shared_memory") else: @@ -236,12 +238,20 @@ def close(self): def unlink(self): """Requests that the underlying shared memory block be destroyed. - In order to ensure proper cleanup of resources, unlink should be - called once (and only once) across all processes which have access - to the shared memory block.""" + Unlink should be called once (and only once) across all handles + which have access to the shared memory block, even if these + handles belong to different processes. 
Closing and unlinking may + happen in any order, but trying to access data inside a shared + memory block after unlinking may result in memory errors, + depending on platform. + + This method has no effect on Windows, where the only way to + delete a shared memory block is to close all handles.""" + if _USE_POSIX and self._name: _posixshmem.shm_unlink(self._name) - resource_tracker.unregister(self._name, "shared_memory") + if self._track: + resource_tracker.unregister(self._name, "shared_memory") _encoding = "utf8" diff --git a/Lib/multiprocessing/synchronize.py b/Lib/multiprocessing/synchronize.py index 3ccbfe311c..870c91349b 100644 --- a/Lib/multiprocessing/synchronize.py +++ b/Lib/multiprocessing/synchronize.py @@ -174,7 +174,7 @@ def __repr__(self): name = process.current_process().name if threading.current_thread().name != 'MainThread': name += '|' + threading.current_thread().name - elif self._semlock._get_value() == 1: + elif not self._semlock._is_zero(): name = 'None' elif self._semlock._count() > 0: name = 'SomeOtherThread' @@ -200,7 +200,7 @@ def __repr__(self): if threading.current_thread().name != 'MainThread': name += '|' + threading.current_thread().name count = self._semlock._count() - elif self._semlock._get_value() == 1: + elif not self._semlock._is_zero(): name, count = 'None', 0 elif self._semlock._count() > 0: name, count = 'SomeOtherThread', 'nonzero' @@ -360,7 +360,7 @@ def wait(self, timeout=None): return True return False - def __repr__(self) -> str: + def __repr__(self): set_status = 'set' if self.is_set() else 'unset' return f"<{type(self).__qualname__} at {id(self):#x} {set_status}>" # diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py index 79559823fb..75dde02d88 100644 --- a/Lib/multiprocessing/util.py +++ b/Lib/multiprocessing/util.py @@ -64,8 +64,7 @@ def get_logger(): global _logger import logging - logging._acquireLock() - try: + with logging._lock: if not _logger: _logger = logging.getLogger(LOGGER_NAME) @@ 
-79,9 +78,6 @@ def get_logger(): atexit._exithandlers.remove((_exit_function, (), {})) atexit._exithandlers.append((_exit_function, (), {})) - finally: - logging._releaseLock() - return _logger def log_to_stderr(level=None): @@ -106,11 +102,7 @@ def log_to_stderr(level=None): # Abstract socket support def _platform_supports_abstract_sockets(): - if sys.platform == "linux": - return True - if hasattr(sys, 'getandroidapilevel'): - return True - return False + return sys.platform in ("linux", "android") def is_abstract_socket_namespace(address): @@ -130,10 +122,7 @@ def is_abstract_socket_namespace(address): # def _remove_temp_dir(rmtree, tempdir): - def onerror(func, path, err_info): - if not issubclass(err_info[0], FileNotFoundError): - raise - rmtree(tempdir, onerror=onerror) + rmtree(tempdir) current_process = process.current_process() # current_process() can be None if the finalizer is called diff --git a/Lib/selectors.py b/Lib/selectors.py index c3b065b522..b8e5f6a4f7 100644 --- a/Lib/selectors.py +++ b/Lib/selectors.py @@ -66,12 +66,16 @@ def __init__(self, selector): def __len__(self): return len(self._selector._fd_to_key) + def get(self, fileobj, default=None): + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key.get(fd, default) + def __getitem__(self, fileobj): - try: - fd = self._selector._fileobj_lookup(fileobj) - return self._selector._fd_to_key[fd] - except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + fd = self._selector._fileobj_lookup(fileobj) + key = self._selector._fd_to_key.get(fd) + if key is None: + raise KeyError("{!r} is not registered".format(fileobj)) + return key def __iter__(self): return iter(self._selector._fd_to_key) @@ -272,19 +276,6 @@ def close(self): def get_map(self): return self._map - def _key_from_fd(self, fd): - """Return the key associated to a given file descriptor. 
- - Parameters: - fd -- file descriptor - - Returns: - corresponding key, or None if not found - """ - try: - return self._fd_to_key[fd] - except KeyError: - return None class SelectSelector(_BaseSelectorImpl): @@ -323,17 +314,15 @@ def select(self, timeout=None): r, w, _ = self._select(self._readers, self._writers, [], timeout) except InterruptedError: return ready - r = set(r) - w = set(w) - for fd in r | w: - events = 0 - if fd in r: - events |= EVENT_READ - if fd in w: - events |= EVENT_WRITE - - key = self._key_from_fd(fd) + r = frozenset(r) + w = frozenset(w) + rw = r | w + fd_to_key_get = self._fd_to_key.get + for fd in rw: + key = fd_to_key_get(fd) if key: + events = ((fd in r and EVENT_READ) + | (fd in w and EVENT_WRITE)) ready.append((key, events & key.events)) return ready @@ -350,11 +339,8 @@ def __init__(self): def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) - poller_events = 0 - if events & EVENT_READ: - poller_events |= self._EVENT_READ - if events & EVENT_WRITE: - poller_events |= self._EVENT_WRITE + poller_events = ((events & EVENT_READ and self._EVENT_READ) + | (events & EVENT_WRITE and self._EVENT_WRITE) ) try: self._selector.register(key.fd, poller_events) except: @@ -380,11 +366,8 @@ def modify(self, fileobj, events, data=None): changed = False if events != key.events: - selector_events = 0 - if events & EVENT_READ: - selector_events |= self._EVENT_READ - if events & EVENT_WRITE: - selector_events |= self._EVENT_WRITE + selector_events = ((events & EVENT_READ and self._EVENT_READ) + | (events & EVENT_WRITE and self._EVENT_WRITE)) try: self._selector.modify(key.fd, selector_events) except: @@ -415,15 +398,13 @@ def select(self, timeout=None): fd_event_list = self._selector.poll(timeout) except InterruptedError: return ready - for fd, event in fd_event_list: - events = 0 - if event & ~self._EVENT_READ: - events |= EVENT_WRITE - if event & ~self._EVENT_WRITE: - events |= EVENT_READ - key = 
self._key_from_fd(fd) + fd_to_key_get = self._fd_to_key.get + for fd, event in fd_event_list: + key = fd_to_key_get(fd) if key: + events = ((event & ~self._EVENT_READ and EVENT_WRITE) + | (event & ~self._EVENT_WRITE and EVENT_READ)) ready.append((key, events & key.events)) return ready @@ -439,6 +420,9 @@ class PollSelector(_PollLikeSelector): if hasattr(select, 'epoll'): + _NOT_EPOLLIN = ~select.EPOLLIN + _NOT_EPOLLOUT = ~select.EPOLLOUT + class EpollSelector(_PollLikeSelector): """Epoll-based selector.""" _selector_cls = select.epoll @@ -461,22 +445,20 @@ def select(self, timeout=None): # epoll_wait() expects `maxevents` to be greater than zero; # we want to make sure that `select()` can be called when no # FD is registered. - max_ev = max(len(self._fd_to_key), 1) + max_ev = len(self._fd_to_key) or 1 ready = [] try: fd_event_list = self._selector.poll(timeout, max_ev) except InterruptedError: return ready - for fd, event in fd_event_list: - events = 0 - if event & ~select.EPOLLIN: - events |= EVENT_WRITE - if event & ~select.EPOLLOUT: - events |= EVENT_READ - key = self._key_from_fd(fd) + fd_to_key = self._fd_to_key + for fd, event in fd_event_list: + key = fd_to_key.get(fd) if key: + events = ((event & _NOT_EPOLLIN and EVENT_WRITE) + | (event & _NOT_EPOLLOUT and EVENT_READ)) ready.append((key, events & key.events)) return ready @@ -566,17 +548,15 @@ def select(self, timeout=None): kev_list = self._selector.control(None, max_ev, timeout) except InterruptedError: return ready + + fd_to_key_get = self._fd_to_key.get for kev in kev_list: fd = kev.ident flag = kev.filter - events = 0 - if flag == select.KQ_FILTER_READ: - events |= EVENT_READ - if flag == select.KQ_FILTER_WRITE: - events |= EVENT_WRITE - - key = self._key_from_fd(fd) + key = fd_to_key_get(fd) if key: + events = ((flag == select.KQ_FILTER_READ and EVENT_READ) + | (flag == select.KQ_FILTER_WRITE and EVENT_WRITE)) ready.append((key, events & key.events)) return ready diff --git a/Lib/stat.py 
b/Lib/stat.py index fc024db3f4..1b4ed1ebc9 100644 --- a/Lib/stat.py +++ b/Lib/stat.py @@ -110,22 +110,30 @@ def S_ISWHT(mode): S_IXOTH = 0o0001 # execute by others # Names for file flags - +UF_SETTABLE = 0x0000ffff # owner settable flags UF_NODUMP = 0x00000001 # do not dump file UF_IMMUTABLE = 0x00000002 # file may not be changed UF_APPEND = 0x00000004 # file may only be appended to UF_OPAQUE = 0x00000008 # directory is opaque when viewed through a union stack UF_NOUNLINK = 0x00000010 # file may not be renamed or deleted -UF_COMPRESSED = 0x00000020 # OS X: file is hfs-compressed -UF_HIDDEN = 0x00008000 # OS X: file should not be displayed +UF_COMPRESSED = 0x00000020 # macOS: file is compressed +UF_TRACKED = 0x00000040 # macOS: used for handling document IDs +UF_DATAVAULT = 0x00000080 # macOS: entitlement needed for I/O +UF_HIDDEN = 0x00008000 # macOS: file should not be displayed +SF_SETTABLE = 0xffff0000 # superuser settable flags SF_ARCHIVED = 0x00010000 # file may be archived SF_IMMUTABLE = 0x00020000 # file may not be changed SF_APPEND = 0x00040000 # file may only be appended to +SF_RESTRICTED = 0x00080000 # macOS: entitlement needed for writing SF_NOUNLINK = 0x00100000 # file may not be renamed or deleted SF_SNAPSHOT = 0x00200000 # file is a snapshot file +SF_FIRMLINK = 0x00800000 # macOS: file is a firmlink +SF_DATALESS = 0x40000000 # macOS: file is a dataless object _filemode_table = ( + # File type chars according to: + # http://en.wikibooks.org/wiki/C_Programming/POSIX_Reference/sys/stat.h ((S_IFLNK, "l"), (S_IFSOCK, "s"), # Must appear before IFREG and IFDIR as IFSOCK == IFREG | IFDIR (S_IFREG, "-"), @@ -156,13 +164,17 @@ def S_ISWHT(mode): def filemode(mode): """Convert a file's mode to a string of the form '-rwxrwxrwx'.""" perm = [] - for table in _filemode_table: + for index, table in enumerate(_filemode_table): for bit, char in table: if mode & bit == bit: perm.append(char) break else: - perm.append("-") + if index == 0: + # Unknown filetype + 
perm.append("?") + else: + perm.append("-") return "".join(perm) diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 9e688efb1e..0b8de96f1b 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -12,6 +12,7 @@ import sys import os import gc +import importlib import errno import functools import signal @@ -19,10 +20,11 @@ import socket import random import logging +import shutil import subprocess import struct +import tempfile import operator -import pathlib import pickle import weakref import warnings @@ -50,7 +52,7 @@ import multiprocessing.managers import multiprocessing.pool import multiprocessing.queues -from multiprocessing.connection import wait, AuthenticationError +from multiprocessing.connection import wait from multiprocessing import util @@ -255,6 +257,9 @@ def __call__(self, *args, **kwds): class BaseTestCase(object): ALLOWED_TYPES = ('processes', 'manager', 'threads') + # If not empty, limit which start method suites run this class. 
+ START_METHODS: set[str] = set() + start_method = None # set by install_tests_in_module_dict() def assertTimingAlmostEqual(self, a, b): if CHECK_TIMINGS: @@ -324,8 +329,9 @@ def test_set_executable(self): self.skipTest(f'test not appropriate for {self.TYPE}') paths = [ sys.executable, # str - sys.executable.encode(), # bytes - pathlib.Path(sys.executable) # os.PathLike + os.fsencode(sys.executable), # bytes + os_helper.FakePath(sys.executable), # os.PathLike + os_helper.FakePath(os.fsencode(sys.executable)), # os.PathLike bytes ] for path in paths: self.set_executable(path) @@ -505,6 +511,11 @@ def _test_process_mainthread_native_id(cls, q): def _sleep_some(cls): time.sleep(100) + @classmethod + def _sleep_some_event(cls, event): + event.set() + time.sleep(100) + @classmethod def _test_sleep(cls, delay): time.sleep(delay) @@ -513,7 +524,8 @@ def _kill_process(self, meth): if self.TYPE == 'threads': self.skipTest('test not appropriate for {}'.format(self.TYPE)) - p = self.Process(target=self._sleep_some) + event = self.Event() + p = self.Process(target=self._sleep_some_event, args=(event,)) p.daemon = True p.start() @@ -531,8 +543,11 @@ def _kill_process(self, meth): self.assertTimingAlmostEqual(join.elapsed, 0.0) self.assertEqual(p.is_alive(), True) - # XXX maybe terminating too soon causes the problems on Gentoo... 
- time.sleep(1) + timeout = support.SHORT_TIMEOUT + if not event.wait(timeout): + p.terminate() + p.join() + self.fail(f"event not signaled in {timeout} seconds") meth(p) @@ -582,12 +597,16 @@ def test_cpu_count(self): def test_active_children(self): self.assertEqual(type(self.active_children()), list) - p = self.Process(target=time.sleep, args=(DELTA,)) + event = self.Event() + p = self.Process(target=event.wait, args=()) self.assertNotIn(p, self.active_children()) - p.daemon = True - p.start() - self.assertIn(p, self.active_children()) + try: + p.daemon = True + p.start() + self.assertIn(p, self.active_children()) + finally: + event.set() p.join() self.assertNotIn(p, self.active_children()) @@ -1332,6 +1351,23 @@ def _on_queue_feeder_error(e, obj): self.assertTrue(not_serializable_obj.reduce_was_called) self.assertTrue(not_serializable_obj.on_queue_feeder_error_was_called) + def test_closed_queue_empty_exceptions(self): + # Assert that checking the emptiness of an unused closed queue + # does not raise an OSError. The rationale is that q.close() is + # a no-op upon construction and becomes effective once the queue + # has been used (e.g., by calling q.put()). 
+ for q in multiprocessing.Queue(), multiprocessing.JoinableQueue(): + q.close() # this is a no-op since the feeder thread is None + q.join_thread() # this is also a no-op + self.assertTrue(q.empty()) + + for q in multiprocessing.Queue(), multiprocessing.JoinableQueue(): + q.put('foo') # make sure that the queue is 'used' + q.close() # close the feeder thread + q.join_thread() # make sure to join the feeder thread + with self.assertRaisesRegex(OSError, 'is closed'): + q.empty() + def test_closed_queue_put_get_exceptions(self): for q in multiprocessing.Queue(), multiprocessing.JoinableQueue(): q.close() @@ -1345,6 +1381,66 @@ def test_closed_queue_put_get_exceptions(self): class _TestLock(BaseTestCase): + @staticmethod + def _acquire(lock, l=None): + lock.acquire() + if l is not None: + l.append(repr(lock)) + + @staticmethod + def _acquire_event(lock, event): + lock.acquire() + event.set() + time.sleep(1.0) + + def test_repr_lock(self): + if self.TYPE != 'processes': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + lock = self.Lock() + self.assertEqual(f'', repr(lock)) + + lock.acquire() + self.assertEqual(f'', repr(lock)) + lock.release() + + tname = 'T1' + l = [] + t = threading.Thread(target=self._acquire, + args=(lock, l), + name=tname) + t.start() + time.sleep(0.1) + self.assertEqual(f'', l[0]) + lock.release() + + t = threading.Thread(target=self._acquire, + args=(lock,), + name=tname) + t.start() + time.sleep(0.1) + self.assertEqual('', repr(lock)) + lock.release() + + pname = 'P1' + l = multiprocessing.Manager().list() + p = self.Process(target=self._acquire, + args=(lock, l), + name=pname) + p.start() + p.join() + self.assertEqual(f'', l[0]) + + lock = self.Lock() + event = self.Event() + p = self.Process(target=self._acquire_event, + args=(lock, event), + name='P2') + p.start() + event.wait() + self.assertEqual(f'', repr(lock)) + p.terminate() + def test_lock(self): lock = self.Lock() self.assertEqual(lock.acquire(), True) @@ -1352,6 
+1448,68 @@ def test_lock(self): self.assertEqual(lock.release(), None) self.assertRaises((ValueError, threading.ThreadError), lock.release) + @staticmethod + def _acquire_release(lock, timeout, l=None, n=1): + for _ in range(n): + lock.acquire() + if l is not None: + l.append(repr(lock)) + time.sleep(timeout) + for _ in range(n): + lock.release() + + def test_repr_rlock(self): + if self.TYPE != 'processes': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + lock = self.RLock() + self.assertEqual('', repr(lock)) + + n = 3 + for _ in range(n): + lock.acquire() + self.assertEqual(f'', repr(lock)) + for _ in range(n): + lock.release() + + t, l = [], [] + for i in range(n): + t.append(threading.Thread(target=self._acquire_release, + args=(lock, 0.1, l, i+1), + name=f'T{i+1}')) + t[-1].start() + for t_ in t: + t_.join() + for i in range(n): + self.assertIn(f'', l) + + + t = threading.Thread(target=self._acquire_release, + args=(lock, 0.2), + name=f'T1') + t.start() + time.sleep(0.1) + self.assertEqual('', repr(lock)) + time.sleep(0.2) + + pname = 'P1' + l = multiprocessing.Manager().list() + p = self.Process(target=self._acquire_release, + args=(lock, 0.1, l), + name=pname) + p.start() + p.join() + self.assertEqual(f'', l[0]) + + event = self.Event() + lock = self.RLock() + p = self.Process(target=self._acquire_event, + args=(lock, event)) + p.start() + event.wait() + self.assertEqual('', repr(lock)) + p.join() + def test_rlock(self): lock = self.RLock() self.assertEqual(lock.acquire(), True) @@ -1432,14 +1590,13 @@ def f(cls, cond, sleeping, woken, timeout=None): cond.release() def assertReachesEventually(self, func, value): - for i in range(10): + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): try: if func() == value: break except NotImplementedError: break - time.sleep(DELTA) - time.sleep(DELTA) + self.assertReturnsIfImplemented(value, func) def check_invariant(self, cond): @@ -1461,20 +1618,17 @@ def test_notify(self): p = 
self.Process(target=self.f, args=(cond, sleeping, woken)) p.daemon = True p.start() - self.addCleanup(p.join) - p = threading.Thread(target=self.f, args=(cond, sleeping, woken)) - p.daemon = True - p.start() - self.addCleanup(p.join) + t = threading.Thread(target=self.f, args=(cond, sleeping, woken)) + t.daemon = True + t.start() # wait for both children to start sleeping sleeping.acquire() sleeping.acquire() # check no process/thread has woken up - time.sleep(DELTA) - self.assertReturnsIfImplemented(0, get_value, woken) + self.assertReachesEventually(lambda: get_value(woken), 0) # wake up one process/thread cond.acquire() @@ -1482,8 +1636,7 @@ def test_notify(self): cond.release() # check one process/thread has woken up - time.sleep(DELTA) - self.assertReturnsIfImplemented(1, get_value, woken) + self.assertReachesEventually(lambda: get_value(woken), 1) # wake up another cond.acquire() @@ -1491,12 +1644,13 @@ def test_notify(self): cond.release() # check other has woken up - time.sleep(DELTA) - self.assertReturnsIfImplemented(2, get_value, woken) + self.assertReachesEventually(lambda: get_value(woken), 2) # check state is not mucked up self.check_invariant(cond) - p.join() + + threading_helper.join_thread(t) + join_process(p) def test_notify_all(self): cond = self.Condition() @@ -1504,18 +1658,19 @@ def test_notify_all(self): woken = self.Semaphore(0) # start some threads/processes which will timeout + workers = [] for i in range(3): p = self.Process(target=self.f, args=(cond, sleeping, woken, TIMEOUT1)) p.daemon = True p.start() - self.addCleanup(p.join) + workers.append(p) t = threading.Thread(target=self.f, args=(cond, sleeping, woken, TIMEOUT1)) t.daemon = True t.start() - self.addCleanup(t.join) + workers.append(t) # wait for them all to sleep for i in range(6): @@ -1534,12 +1689,12 @@ def test_notify_all(self): p = self.Process(target=self.f, args=(cond, sleeping, woken)) p.daemon = True p.start() - self.addCleanup(p.join) + workers.append(p) t = 
threading.Thread(target=self.f, args=(cond, sleeping, woken)) t.daemon = True t.start() - self.addCleanup(t.join) + workers.append(t) # wait for them to all sleep for i in range(6): @@ -1555,27 +1710,34 @@ def test_notify_all(self): cond.release() # check they have all woken - self.assertReachesEventually(lambda: get_value(woken), 6) + for i in range(6): + woken.acquire() + self.assertReturnsIfImplemented(0, get_value, woken) # check state is not mucked up self.check_invariant(cond) + for w in workers: + # NOTE: join_process and join_thread are the same + threading_helper.join_thread(w) + def test_notify_n(self): cond = self.Condition() sleeping = self.Semaphore(0) woken = self.Semaphore(0) # start some threads/processes + workers = [] for i in range(3): p = self.Process(target=self.f, args=(cond, sleeping, woken)) p.daemon = True p.start() - self.addCleanup(p.join) + workers.append(p) t = threading.Thread(target=self.f, args=(cond, sleeping, woken)) t.daemon = True t.start() - self.addCleanup(t.join) + workers.append(t) # wait for them to all sleep for i in range(6): @@ -1610,6 +1772,10 @@ def test_notify_n(self): # check state is not mucked up self.check_invariant(cond) + for w in workers: + # NOTE: join_process and join_thread are the same + threading_helper.join_thread(w) + def test_timeout(self): cond = self.Condition() wait = TimingWrapper(cond.wait) @@ -2812,8 +2978,8 @@ def test_release_task_refs(self): self.pool.map(identity, objs) del objs - gc.collect() # For PyPy or other GCs. time.sleep(DELTA) # let threaded cleanup code run + support.gc_collect() # For PyPy or other GCs. self.assertEqual(set(wr() for wr in refs), {None}) # With a process pool, copies of the objects are returned, check # they were released too. 
@@ -3174,6 +3340,44 @@ def test_rapid_restart(self): if hasattr(manager, "shutdown"): self.addCleanup(manager.shutdown) + +class FakeConnection: + def send(self, payload): + pass + + def recv(self): + return '#ERROR', pyqueue.Empty() + +class TestManagerExceptions(unittest.TestCase): + # Issue 106558: Manager exceptions avoids creating cyclic references. + def setUp(self): + self.mgr = multiprocessing.Manager() + + def tearDown(self): + self.mgr.shutdown() + self.mgr.join() + + def test_queue_get(self): + queue = self.mgr.Queue() + if gc.isenabled(): + gc.disable() + self.addCleanup(gc.enable) + try: + queue.get_nowait() + except pyqueue.Empty as e: + wr = weakref.ref(e) + self.assertEqual(wr(), None) + + def test_dispatch(self): + if gc.isenabled(): + gc.disable() + self.addCleanup(gc.enable) + try: + multiprocessing.managers.dispatch(FakeConnection(), None, None) + except pyqueue.Empty as e: + wr = weakref.ref(e) + self.assertEqual(wr(), None) + # # # @@ -4462,6 +4666,59 @@ def test_shared_memory_cleaned_after_process_termination(self): "resource_tracker: There appear to be 1 leaked " "shared_memory objects to clean up at shutdown", err) + @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") + def test_shared_memory_untracking(self): + # gh-82300: When a separate Python process accesses shared memory + # with track=False, it must not cause the memory to be deleted + # when terminating. + cmd = '''if 1: + import sys + from multiprocessing.shared_memory import SharedMemory + mem = SharedMemory(create=False, name=sys.argv[1], track=False) + mem.close() + ''' + mem = shared_memory.SharedMemory(create=True, size=10) + # The resource tracker shares pipes with the subprocess, and so + # err existing means that the tracker process has terminated now. 
+ try: + rc, out, err = script_helper.assert_python_ok("-c", cmd, mem.name) + self.assertNotIn(b"resource_tracker", err) + self.assertEqual(rc, 0) + mem2 = shared_memory.SharedMemory(create=False, name=mem.name) + mem2.close() + finally: + try: + mem.unlink() + except OSError: + pass + mem.close() + + @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") + def test_shared_memory_tracking(self): + # gh-82300: When a separate Python process accesses shared memory + # with track=True, it must cause the memory to be deleted when + # terminating. + cmd = '''if 1: + import sys + from multiprocessing.shared_memory import SharedMemory + mem = SharedMemory(create=False, name=sys.argv[1], track=True) + mem.close() + ''' + mem = shared_memory.SharedMemory(create=True, size=10) + try: + rc, out, err = script_helper.assert_python_ok("-c", cmd, mem.name) + self.assertEqual(rc, 0) + self.assertIn( + b"resource_tracker: There appear to be 1 leaked " + b"shared_memory objects to clean up at shutdown", err) + finally: + try: + mem.unlink() + except OSError: + pass + resource_tracker.unregister(mem._name, "shared_memory") + mem.close() + # # Test to verify that `Finalize` works. 
# @@ -4571,7 +4828,7 @@ def make_finalizers(): old_interval = sys.getswitchinterval() old_threshold = gc.get_threshold() try: - sys.setswitchinterval(1e-6) + support.setswitchinterval(1e-6) gc.set_threshold(5, 5, 5) threads = [threading.Thread(target=run_finalizers), threading.Thread(target=make_finalizers)] @@ -5557,8 +5814,9 @@ def create_and_register_resource(rtype): ''' for rtype in resource_tracker._CLEANUP_FUNCS: with self.subTest(rtype=rtype): - if rtype == "noop": + if rtype in ("noop", "dummy"): # Artefact resource type used by the resource_tracker + # or tests continue r, w = os.pipe() p = subprocess.Popen([sys.executable, @@ -5638,6 +5896,8 @@ def test_resource_tracker_sigterm(self): # Catchable signal (ignored by semaphore tracker) self.check_resource_tracker_death(signal.SIGTERM, False) + @unittest.skipIf(sys.platform.startswith("netbsd"), + "gh-125620: Skip on NetBSD due to long wait for SIGKILL process termination.") def test_resource_tracker_sigkill(self): # Uncatchable signal. self.check_resource_tracker_death(signal.SIGKILL, True) @@ -5678,6 +5938,59 @@ def test_too_long_name_resource(self): with self.assertRaises(ValueError): resource_tracker.register(too_long_name_resource, rtype) + def _test_resource_tracker_leak_resources(self, cleanup): + # We use a separate instance for testing, since the main global + # _resource_tracker may be used to watch test infrastructure. 
+ from multiprocessing.resource_tracker import ResourceTracker + tracker = ResourceTracker() + tracker.ensure_running() + self.assertTrue(tracker._check_alive()) + + self.assertIsNone(tracker._exitcode) + tracker.register('somename', 'dummy') + if cleanup: + tracker.unregister('somename', 'dummy') + expected_exit_code = 0 + else: + expected_exit_code = 1 + + self.assertTrue(tracker._check_alive()) + self.assertIsNone(tracker._exitcode) + tracker._stop() + self.assertEqual(tracker._exitcode, expected_exit_code) + + def test_resource_tracker_exit_code(self): + """ + Test the exit code of the resource tracker. + + If no leaked resources were found, exit code should be 0, otherwise 1 + """ + for cleanup in [True, False]: + with self.subTest(cleanup=cleanup): + self._test_resource_tracker_leak_resources( + cleanup=cleanup, + ) + + @unittest.skipUnless(hasattr(signal, "pthread_sigmask"), "pthread_sigmask is not available") + def test_resource_tracker_blocked_signals(self): + # + # gh-127586: Check that resource_tracker does not override blocked signals of caller. 
+ # + from multiprocessing.resource_tracker import ResourceTracker + orig_sigmask = signal.pthread_sigmask(signal.SIG_BLOCK, set()) + signals = {signal.SIGTERM, signal.SIGINT, signal.SIGUSR1} + + try: + for sig in signals: + signal.pthread_sigmask(signal.SIG_SETMASK, {sig}) + self.assertEqual(signal.pthread_sigmask(signal.SIG_BLOCK, set()), {sig}) + tracker = ResourceTracker() + tracker.ensure_running() + self.assertEqual(signal.pthread_sigmask(signal.SIG_BLOCK, set()), {sig}) + tracker._stop() + finally: + # restore sigmask to what it was before executing test + signal.pthread_sigmask(signal.SIG_SETMASK, orig_sigmask) class TestSimpleQueue(unittest.TestCase): @@ -5691,6 +6004,15 @@ def _test_empty(cls, queue, child_can_start, parent_can_continue): finally: parent_can_continue.set() + def test_empty_exceptions(self): + # Assert that checking emptiness of a closed queue raises + # an OSError, independently of whether the queue was used + # or not. This differs from Queue and JoinableQueue. 
+ q = multiprocessing.SimpleQueue() + q.close() # close the pipe + with self.assertRaisesRegex(OSError, 'is closed'): + q.empty() + def test_empty(self): queue = multiprocessing.SimpleQueue() child_can_start = multiprocessing.Event() @@ -6037,6 +6359,99 @@ def submain(): pass self.assertFalse(err, msg=err.decode('utf-8')) +class _TestAtExit(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + @classmethod + def _write_file_at_exit(self, output_path): + import atexit + def exit_handler(): + with open(output_path, 'w') as f: + f.write("deadbeef") + atexit.register(exit_handler) + + def test_atexit(self): + # gh-83856 + with os_helper.temp_dir() as temp_dir: + output_path = os.path.join(temp_dir, 'output.txt') + p = self.Process(target=self._write_file_at_exit, args=(output_path,)) + p.start() + p.join() + with open(output_path) as f: + self.assertEqual(f.read(), 'deadbeef') + + +class _TestSpawnedSysPath(BaseTestCase): + """Test that sys.path is setup in forkserver and spawn processes.""" + + ALLOWED_TYPES = {'processes'} + # Not applicable to fork which inherits everything from the process as is. + START_METHODS = {"forkserver", "spawn"} + + def setUp(self): + self._orig_sys_path = list(sys.path) + self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-") + self._mod_name = "unique_test_mod" + module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py") + with open(module_path, "w", encoding="utf-8") as mod: + mod.write("# A simple test module\n") + sys.path[:] = [p for p in sys.path if p] # remove any existing ""s + sys.path.insert(0, self._temp_dir) + sys.path.insert(0, "") # Replaced with an abspath in child. 
+ self.assertIn(self.start_method, self.START_METHODS) + self._ctx = multiprocessing.get_context(self.start_method) + + def tearDown(self): + sys.path[:] = self._orig_sys_path + shutil.rmtree(self._temp_dir, ignore_errors=True) + + @staticmethod + def enq_imported_module_names(queue): + queue.put(tuple(sys.modules)) + + def test_forkserver_preload_imports_sys_path(self): + if self._ctx.get_start_method() != "forkserver": + self.skipTest("forkserver specific test.") + self.assertNotIn(self._mod_name, sys.modules) + multiprocessing.forkserver._forkserver._stop() # Must be fresh. + self._ctx.set_forkserver_preload( + ["test.test_multiprocessing_forkserver", self._mod_name]) + q = self._ctx.Queue() + proc = self._ctx.Process( + target=self.enq_imported_module_names, args=(q,)) + proc.start() + proc.join() + child_imported_modules = q.get() + q.close() + self.assertIn(self._mod_name, child_imported_modules) + + @staticmethod + def enq_sys_path_and_import(queue, mod_name): + queue.put(sys.path) + try: + importlib.import_module(mod_name) + except ImportError as exc: + queue.put(exc) + else: + queue.put(None) + + def test_child_sys_path(self): + q = self._ctx.Queue() + proc = self._ctx.Process( + target=self.enq_sys_path_and_import, args=(q, self._mod_name)) + proc.start() + proc.join() + child_sys_path = q.get() + import_error = q.get() + q.close() + self.assertNotIn("", child_sys_path) # replaced by an abspath + self.assertIn(self._temp_dir, child_sys_path) # our addition + # ignore the first element, it is the absolute "" replacement + self.assertEqual(child_sys_path[1:], sys.path[1:]) + self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") + + class MiscTestCase(unittest.TestCase): def test__all__(self): # Just make sure names in not_exported are excluded @@ -6061,6 +6476,46 @@ def test_spawn_sys_executable_none_allows_import(self): self.assertEqual(rc, 0) self.assertFalse(err, msg=err.decode('utf-8')) + def test_large_pool(self): + # + # 
gh-89240: Check that large pools are always okay + # + testfn = os_helper.TESTFN + self.addCleanup(os_helper.unlink, testfn) + with open(testfn, 'w', encoding='utf-8') as f: + f.write(textwrap.dedent('''\ + import multiprocessing + def f(x): return x*x + if __name__ == '__main__': + with multiprocessing.Pool(200) as p: + print(sum(p.map(f, range(1000)))) + ''')) + rc, out, err = script_helper.assert_python_ok(testfn) + self.assertEqual("332833500", out.decode('utf-8').strip()) + self.assertFalse(err, msg=err.decode('utf-8')) + + def test_forked_thread_not_started(self): + # gh-134381: Ensure that a thread that has not been started yet in + # the parent process can be started within a forked child process. + + if multiprocessing.get_start_method() != "fork": + self.skipTest("fork specific test") + + q = multiprocessing.Queue() + t = threading.Thread(target=lambda: q.put("done"), daemon=True) + + def child(): + t.start() + t.join() + + p = multiprocessing.Process(target=child) + p.start() + p.join(support.SHORT_TIMEOUT) + + self.assertEqual(p.exitcode, 0) + self.assertEqual(q.get_nowait(), "done") + close_queue(q) + # # Mixins @@ -6213,6 +6668,8 @@ def install_tests_in_module_dict(remote_globs, start_method, if base is BaseTestCase: continue assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES + if base.START_METHODS and start_method not in base.START_METHODS: + continue # class not intended for this start method. 
for type_ in base.ALLOWED_TYPES: if only_type and type_ != only_type: continue @@ -6226,6 +6683,7 @@ class Temp(base, Mixin, unittest.TestCase): Temp = hashlib_helper.requires_hashdigest('sha256')(Temp) Temp.__name__ = Temp.__qualname__ = newname Temp.__module__ = __module__ + Temp.start_method = start_method remote_globs[newname] = Temp elif issubclass(base, unittest.TestCase): if only_type: diff --git a/Lib/test/_test_venv_multiprocessing.py b/Lib/test/_test_venv_multiprocessing.py new file mode 100644 index 0000000000..ad985dd8d5 --- /dev/null +++ b/Lib/test/_test_venv_multiprocessing.py @@ -0,0 +1,40 @@ +import multiprocessing +import random +import sys + +def fill_queue(queue, code): + queue.put(code) + + +def drain_queue(queue, code): + if code != queue.get(): + sys.exit(1) + + +def test_func(): + code = random.randrange(0, 1000) + queue = multiprocessing.Queue() + fill_pool = multiprocessing.Process( + target=fill_queue, + args=(queue, code) + ) + drain_pool = multiprocessing.Process( + target=drain_queue, + args=(queue, code) + ) + drain_pool.start() + fill_pool.start() + fill_pool.join() + drain_pool.join() + + +def main(): + multiprocessing.set_start_method('spawn') + test_pool = multiprocessing.Process(target=test_func) + test_pool.start() + test_pool.join() + sys.exit(test_pool.exitcode) + + +if __name__ == "__main__": + main() diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index be89bec522..0c20e27cfd 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -1285,8 +1285,6 @@ def check_overflow(self, lower, upper): self.assertRaises(OverflowError, array.array, self.typecode, [upper+1]) self.assertRaises(OverflowError, a.__setitem__, 0, upper+1) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_subclassing(self): typecode = self.typecode class ExaggeratingArray(array.array): diff --git a/Lib/test/test_importlib/test_metadata_api.py b/Lib/test/test_importlib/test_metadata_api.py index 55c9f8007e..33c6e85ee9 100644 --- 
a/Lib/test/test_importlib/test_metadata_api.py +++ b/Lib/test/test_importlib/test_metadata_api.py @@ -139,8 +139,6 @@ def test_entry_points_missing_name(self): def test_entry_points_missing_group(self): assert entry_points(group='missing') == () - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_entry_points_allows_no_attributes(self): ep = entry_points().select(group='entries', name='main') with self.assertRaises(AttributeError): diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 370685f1c6..0039d34b5e 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -100,8 +100,7 @@ def setUp(self): self._threading_key = threading_helper.threading_setup() logger_dict = logging.getLogger().manager.loggerDict - logging._acquireLock() - try: + with logging._lock: self.saved_handlers = logging._handlers.copy() self.saved_handler_list = logging._handlerList[:] self.saved_loggers = saved_loggers = logger_dict.copy() @@ -111,8 +110,6 @@ def setUp(self): for name in saved_loggers: logger_states[name] = getattr(saved_loggers[name], 'disabled', None) - finally: - logging._releaseLock() # Set two unused loggers self.logger1 = logging.getLogger("\xab\xd7\xbb") @@ -146,8 +143,7 @@ def tearDown(self): self.root_logger.removeHandler(h) h.close() self.root_logger.setLevel(self.original_logging_level) - logging._acquireLock() - try: + with logging._lock: logging._levelToName.clear() logging._levelToName.update(self.saved_level_to_name) logging._nameToLevel.clear() @@ -164,8 +160,6 @@ def tearDown(self): for name in self.logger_states: if logger_states[name] is not None: self.saved_loggers[name].disabled = logger_states[name] - finally: - logging._releaseLock() self.doCleanups() threading_helper.threading_cleanup(*self._threading_key) @@ -182,7 +176,7 @@ def assert_log_lines(self, expected_values, stream=None, pat=None): match = pat.search(actual) if not match: self.fail("Log line does not match expected pattern:\n" + - actual) + actual) 
self.assertEqual(tuple(match.groups()), expected) s = stream.read() if s: @@ -588,7 +582,7 @@ def test_specific_filters(self): ('Effusive', '17'), ('Terse', '18'), ('Silent', '20'), - ]) + ]) finally: if specific_filter: self.root_logger.removeFilter(specific_filter) @@ -613,7 +607,7 @@ def test_name(self): def test_builtin_handlers(self): # We can't actually *use* too many handlers in the tests, # but we can try instantiating them with various options - if sys.platform in ('linux', 'darwin'): + if sys.platform in ('linux', 'android', 'darwin'): for existing in (True, False): fn = make_temp_file() if not existing: @@ -673,11 +667,11 @@ def test_pathlike_objects(self): os.unlink(fn) pfn = os_helper.FakePath(fn) cases = ( - (logging.FileHandler, (pfn, 'w')), - (logging.handlers.RotatingFileHandler, (pfn, 'a')), - (logging.handlers.TimedRotatingFileHandler, (pfn, 'h')), - ) - if sys.platform in ('linux', 'darwin'): + (logging.FileHandler, (pfn, 'w')), + (logging.handlers.RotatingFileHandler, (pfn, 'a')), + (logging.handlers.TimedRotatingFileHandler, (pfn, 'h')), + ) + if sys.platform in ('linux', 'android', 'darwin'): cases += ((logging.handlers.WatchedFileHandler, (pfn, 'w')),) for cls, args in cases: h = cls(*args, encoding="utf-8") @@ -751,11 +745,8 @@ def __init__(self): stream=open('/dev/null', 'wt', encoding='utf-8')) def emit(self, record): - self.sub_handler.acquire() - try: + with self.sub_handler.lock: self.sub_handler.emit(record) - finally: - self.sub_handler.release() self.assertEqual(len(logging._handlers), 0) refed_h = _OurHandler() @@ -771,33 +762,26 @@ def emit(self, record): fork_happened__release_locks_and_end_thread = threading.Event() def lock_holder_thread_fn(): - logging._acquireLock() - try: - refed_h.acquire() - try: - # Tell the main thread to do the fork. - locks_held__ready_to_fork.set() - - # If the deadlock bug exists, the fork will happen - # without dealing with the locks we hold, deadlocking - # the child. 
- - # Wait for a successful fork or an unreasonable amount of - # time before releasing our locks. To avoid a timing based - # test we'd need communication from os.fork() as to when it - # has actually happened. Given this is a regression test - # for a fixed issue, potentially less reliably detecting - # regression via timing is acceptable for simplicity. - # The test will always take at least this long. :( - fork_happened__release_locks_and_end_thread.wait(0.5) - finally: - refed_h.release() - finally: - logging._releaseLock() + with logging._lock, refed_h.lock: + # Tell the main thread to do the fork. + locks_held__ready_to_fork.set() + + # If the deadlock bug exists, the fork will happen + # without dealing with the locks we hold, deadlocking + # the child. + + # Wait for a successful fork or an unreasonable amount of + # time before releasing our locks. To avoid a timing based + # test we'd need communication from os.fork() as to when it + # has actually happened. Given this is a regression test + # for a fixed issue, potentially less reliably detecting + # regression via timing is acceptable for simplicity. + # The test will always take at least this long. 
:( + fork_happened__release_locks_and_end_thread.wait(0.5) lock_holder_thread = threading.Thread( - target=lock_holder_thread_fn, - name='test_post_fork_child_no_deadlock lock holder') + target=lock_holder_thread_fn, + name='test_post_fork_child_no_deadlock lock holder') lock_holder_thread.start() locks_held__ready_to_fork.wait() @@ -1132,7 +1116,7 @@ class SMTPHandlerTest(BaseTest): TIMEOUT = support.LONG_TIMEOUT # TODO: RUSTPYTHON - @unittest.skip(reason="Hangs RustPython") + @unittest.skip(reason="RUSTPYTHON hangs") def test_basic(self): sockmap = {} server = TestSMTPServer((socket_helper.HOST, 0), self.process_message, 0.001, @@ -1657,7 +1641,7 @@ def test_config4_ok(self): logging.exception("just testing") sys.stdout.seek(0) self.assertEqual(output.getvalue(), - "ERROR:root:just testing\nGot a [RuntimeError]\n") + "ERROR:root:just testing\nGot a [RuntimeError]\n") # Original logger output is empty self.assert_log_lines([]) @@ -2171,7 +2155,7 @@ def handle_request(self, request): self.handled.set() # TODO: RUSTPYTHON - @unittest.skip("RUSTPYTHON") + @unittest.skip("TODO: RUSTPYTHON; flaky test") def test_output(self): # The log message sent to the HTTPHandler is properly received. 
logger = logging.getLogger("http") @@ -2195,7 +2179,7 @@ def test_output(self): sslctx = None context = None self.server = server = TestHTTPServer(addr, self.handle_request, - 0.01, sslctx=sslctx) + 0.01, sslctx=sslctx) server.start() server.ready.wait() host = 'localhost:%d' % server.server_port @@ -2212,7 +2196,8 @@ def test_output(self): self.handled.clear() msg = "sp\xe4m" logger.error(msg) - self.handled.wait() + handled = self.handled.wait(support.SHORT_TIMEOUT) + self.assertTrue(handled, "HTTP request timed out") self.assertEqual(self.log_data.path, '/frob') self.assertEqual(self.command, method) if method == 'GET': @@ -2253,7 +2238,7 @@ def _assertTruesurvival(self): dead.append(repr_) if dead: self.fail("%d objects should have survived " - "but have been destroyed: %s" % (len(dead), ", ".join(dead))) + "but have been destroyed: %s" % (len(dead), ", ".join(dead))) def test_persistent_loggers(self): # Logger objects are persistent and retain their configuration, even @@ -2351,7 +2336,7 @@ def test_warnings(self): s = a_file.getvalue() a_file.close() self.assertEqual(s, - "dummy.py:42: UserWarning: Explicit\n Dummy line\n") + "dummy.py:42: UserWarning: Explicit\n Dummy line\n") def test_warnings_no_handlers(self): with warnings.catch_warnings(): @@ -2632,7 +2617,7 @@ class ConfigDictTest(BaseTest): }, 'root' : { 'level' : 'NOTSET', - 'handlers' : ['hand1'], + 'handlers' : ['hand1'], }, } @@ -2666,7 +2651,7 @@ class ConfigDictTest(BaseTest): }, 'root' : { 'level' : 'NOTSET', - 'handlers' : ['hand1'], + 'handlers' : ['hand1'], }, } @@ -3277,7 +3262,7 @@ def format(self, record): 'h1' : { 'class': 'logging.FileHandler', }, - # key is before depended on handlers to test that deferred config works + # key is before depended on handlers to test that deferred config works 'ah' : { 'class': 'logging.handlers.QueueHandler', 'handlers': ['h1'] @@ -3355,7 +3340,7 @@ def test_config4_ok(self): logging.exception("just testing") sys.stdout.seek(0) 
self.assertEqual(output.getvalue(), - "ERROR:root:just testing\nGot a [RuntimeError]\n") + "ERROR:root:just testing\nGot a [RuntimeError]\n") # Original logger output is empty self.assert_log_lines([]) @@ -3370,7 +3355,7 @@ def test_config4a_ok(self): logging.exception("just testing") sys.stdout.seek(0) self.assertEqual(output.getvalue(), - "ERROR:root:just testing\nGot a [RuntimeError]\n") + "ERROR:root:just testing\nGot a [RuntimeError]\n") # Original logger output is empty self.assert_log_lines([]) @@ -3770,7 +3755,28 @@ def test_baseconfig(self): d = { 'atuple': (1, 2, 3), 'alist': ['a', 'b', 'c'], - 'adict': {'d': 'e', 'f': 3 }, + 'adict': { + 'd': 'e', 'f': 3 , + 'alpha numeric 1 with spaces' : 5, + 'alpha numeric 1 %( - © ©ß¯' : 9, + 'alpha numeric ] 1 with spaces' : 15, + 'alpha ]] numeric 1 %( - © ©ß¯]' : 19, + ' alpha [ numeric 1 %( - © ©ß¯] ' : 11, + ' alpha ' : 32, + '' : 10, + 'nest4' : { + 'd': 'e', 'f': 3 , + 'alpha numeric 1 with spaces' : 5, + 'alpha numeric 1 %( - © ©ß¯' : 9, + '' : 10, + 'somelist' : ('g', ('h', 'i'), 'j'), + 'somedict' : { + 'a' : 1, + 'a with 1 and space' : 3, + 'a with ( and space' : 4, + } + } + }, 'nest1': ('g', ('h', 'i'), 'j'), 'nest2': ['k', ['l', 'm'], 'n'], 'nest3': ['o', 'cfg://alist', 'p'], @@ -3782,11 +3788,36 @@ def test_baseconfig(self): self.assertEqual(bc.convert('cfg://nest2[1][1]'), 'm') self.assertEqual(bc.convert('cfg://adict.d'), 'e') self.assertEqual(bc.convert('cfg://adict[f]'), 3) + self.assertEqual(bc.convert('cfg://adict[alpha numeric 1 with spaces]'), 5) + self.assertEqual(bc.convert('cfg://adict[alpha numeric 1 %( - © ©ß¯]'), 9) + self.assertEqual(bc.convert('cfg://adict[]'), 10) + self.assertEqual(bc.convert('cfg://adict.nest4.d'), 'e') + self.assertEqual(bc.convert('cfg://adict.nest4[d]'), 'e') + self.assertEqual(bc.convert('cfg://adict[nest4].d'), 'e') + self.assertEqual(bc.convert('cfg://adict[nest4][f]'), 3) + self.assertEqual(bc.convert('cfg://adict[nest4][alpha numeric 1 with spaces]'), 5) + 
self.assertEqual(bc.convert('cfg://adict[nest4][alpha numeric 1 %( - © ©ß¯]'), 9) + self.assertEqual(bc.convert('cfg://adict[nest4][]'), 10) + self.assertEqual(bc.convert('cfg://adict[nest4][somelist][0]'), 'g') + self.assertEqual(bc.convert('cfg://adict[nest4][somelist][1][0]'), 'h') + self.assertEqual(bc.convert('cfg://adict[nest4][somelist][1][1]'), 'i') + self.assertEqual(bc.convert('cfg://adict[nest4][somelist][2]'), 'j') + self.assertEqual(bc.convert('cfg://adict[nest4].somedict.a'), 1) + self.assertEqual(bc.convert('cfg://adict[nest4].somedict[a]'), 1) + self.assertEqual(bc.convert('cfg://adict[nest4].somedict[a with 1 and space]'), 3) + self.assertEqual(bc.convert('cfg://adict[nest4].somedict[a with ( and space]'), 4) + self.assertEqual(bc.convert('cfg://adict.nest4.somelist[1][1]'), 'i') + self.assertEqual(bc.convert('cfg://adict.nest4.somelist[2]'), 'j') + self.assertEqual(bc.convert('cfg://adict.nest4.somedict.a'), 1) + self.assertEqual(bc.convert('cfg://adict.nest4.somedict[a]'), 1) v = bc.convert('cfg://nest3') self.assertEqual(v.pop(1), ['a', 'b', 'c']) self.assertRaises(KeyError, bc.convert, 'cfg://nosuch') self.assertRaises(ValueError, bc.convert, 'cfg://!') self.assertRaises(KeyError, bc.convert, 'cfg://adict[2]') + self.assertRaises(KeyError, bc.convert, 'cfg://adict[alpha numeric ] 1 with spaces]') + self.assertRaises(ValueError, bc.convert, 'cfg://adict[ alpha ]] numeric 1 %( - © ©ß¯] ]') + self.assertRaises(ValueError, bc.convert, 'cfg://adict[ alpha [ numeric 1 %( - © ©ß¯] ]') def test_namedtuple(self): # see bpo-39142 @@ -3992,7 +4023,7 @@ def test_config_reject_simple_queue_handler_multiprocessing_context(self): @skip_if_tsan_fork @support.requires_subprocess() @unittest.skipUnless(support.Py_DEBUG, "requires a debug build for testing" - "assertions in multiprocessing") + " assertions in multiprocessing") def test_config_queue_handler_multiprocessing_context(self): # regression test for gh-121723 if support.MS_WINDOWS: @@ -4029,8 +4060,9 @@ 
def _mpinit_issue121723(qspec, message_to_log): # log a message (this creates a record put in the queue) logging.getLogger().info(message_to_log) - # TODO: RustPython + # TODO: RUSTPYTHON; ImportError: cannot import name 'SemLock' @unittest.expectedFailure + @skip_if_tsan_fork @support.requires_subprocess() def test_multiprocessing_queues(self): # See gh-119819 @@ -4089,7 +4121,7 @@ def test_90195(self): # Logger should be enabled, since explicitly mentioned self.assertFalse(logger.disabled) - # TODO: RustPython + # TODO: RUSTPYTHON; ImportError: cannot import name 'SemLock' @unittest.expectedFailure def test_111615(self): # See gh-111615 @@ -4139,6 +4171,91 @@ def __init__(self, *args, **kwargs): handler = logging.getHandlerByName('custom') self.assertEqual(handler.custom_kwargs, custom_kwargs) + # TODO: RUSTPYTHON; ImportError: cannot import name 'SemLock' + @unittest.expectedFailure + # See gh-91555 and gh-90321 + @support.requires_subprocess() + def test_deadlock_in_queue(self): + queue = multiprocessing.Queue() + handler = logging.handlers.QueueHandler(queue) + logger = multiprocessing.get_logger() + level = logger.level + try: + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + logger.debug("deadlock") + finally: + logger.setLevel(level) + logger.removeHandler(handler) + + def test_recursion_in_custom_handler(self): + class BadHandler(logging.Handler): + def __init__(self): + super().__init__() + def emit(self, record): + logger.debug("recurse") + logger = logging.getLogger("test_recursion_in_custom_handler") + logger.addHandler(BadHandler()) + logger.setLevel(logging.DEBUG) + logger.debug("boom") + + @threading_helper.requires_working_threading() + def test_thread_supression_noninterference(self): + lock = threading.Lock() + logger = logging.getLogger("test_thread_supression_noninterference") + + # Block on the first call, allow others through + # + # NOTE: We need to bypass the base class's lock, otherwise that will + # block multiple calls to 
the same handler itself. + class BlockOnceHandler(TestHandler): + def __init__(self, barrier): + super().__init__(support.Matcher()) + self.barrier = barrier + + def createLock(self): + self.lock = None + + def handle(self, record): + self.emit(record) + + def emit(self, record): + if self.barrier: + barrier = self.barrier + self.barrier = None + barrier.wait() + with lock: + pass + super().emit(record) + logger.info("blow up if not supressed") + + barrier = threading.Barrier(2) + handler = BlockOnceHandler(barrier) + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + + t1 = threading.Thread(target=logger.debug, args=("1",)) + with lock: + + # Ensure first thread is blocked in the handler, hence supressing logging... + t1.start() + barrier.wait() + + # ...but the second thread should still be able to log... + t2 = threading.Thread(target=logger.debug, args=("2",)) + t2.start() + t2.join(timeout=3) + + self.assertEqual(len(handler.buffer), 1) + self.assertTrue(handler.matches(levelno=logging.DEBUG, message='2')) + + # The first thread should still be blocked here + self.assertTrue(t1.is_alive()) + + # Now the lock has been released the first thread should complete + t1.join() + self.assertEqual(len(handler.buffer), 2) + self.assertTrue(handler.matches(levelno=logging.DEBUG, message='1')) class ManagerTest(BaseTest): def test_manager_loggerclass(self): @@ -4208,7 +4325,7 @@ def filter(self, record): t = type(record) if t is not self.cls: msg = 'Unexpected LogRecord type %s, expected %s' % (t, - self.cls) + self.cls) raise TypeError(msg) return True @@ -4228,7 +4345,7 @@ def test_logrecord_class(self): logging.setLogRecordFactory(DerivedLogRecord) self.root_logger.error(self.next_message()) self.assert_log_lines([ - ('root', 'ERROR', '2'), + ('root', 'ERROR', '2'), ]) @@ -4276,8 +4393,6 @@ def test_formatting(self): self.assertEqual(formatted_msg, log_record.msg) self.assertEqual(formatted_msg, log_record.message) - 
@unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), - 'logging.handlers.QueueListener required for this test') def test_queue_listener(self): handler = TestHandler(support.Matcher()) listener = logging.handlers.QueueListener(self.queue, handler) @@ -4288,6 +4403,7 @@ def test_queue_listener(self): self.que_logger.critical(self.next_message()) finally: listener.stop() + listener.stop() # gh-114706 - ensure no crash if called again self.assertTrue(handler.matches(levelno=logging.WARNING, message='1')) self.assertTrue(handler.matches(levelno=logging.ERROR, message='2')) self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='3')) @@ -4311,8 +4427,18 @@ def test_queue_listener(self): self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='6')) handler.close() - @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), - 'logging.handlers.QueueListener required for this test') + # doesn't hurt to call stop() more than once. + listener.stop() + self.assertIsNone(listener._thread) + + def test_queue_listener_multi_start(self): + handler = TestHandler(support.Matcher()) + listener = logging.handlers.QueueListener(self.queue, handler) + listener.start() + self.assertRaises(RuntimeError, listener.start) + listener.stop() + self.assertIsNone(listener._thread) + def test_queue_listener_with_StreamHandler(self): # Test that traceback and stack-info only appends once (bpo-34334, bpo-46755). listener = logging.handlers.QueueListener(self.queue, self.root_hdlr) @@ -4327,8 +4453,6 @@ def test_queue_listener_with_StreamHandler(self): self.assertEqual(self.stream.getvalue().strip().count('Traceback'), 1) self.assertEqual(self.stream.getvalue().strip().count('Stack'), 1) - @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), - 'logging.handlers.QueueListener required for this test') def test_queue_listener_with_multiple_handlers(self): # Test that queue handler format doesn't affect other handler formats (bpo-35726). 
self.que_hdlr.setFormatter(self.root_formatter) @@ -4344,6 +4468,7 @@ def test_queue_listener_with_multiple_handlers(self): import multiprocessing from unittest.mock import patch + @skip_if_tsan_fork @threading_helper.requires_working_threading() class QueueListenerTest(BaseTest): """ @@ -4427,8 +4552,8 @@ def test_no_messages_in_queue_after_stop(self): expected = [[], [logging.handlers.QueueListener._sentinel]] self.assertIn(items, expected, 'Found unexpected messages in queue: %s' % ( - [m.msg if isinstance(m, logging.LogRecord) - else m for m in items])) + [m.msg if isinstance(m, logging.LogRecord) + else m for m in items])) def test_calls_task_done_after_stop(self): # Issue 36813: Make sure queue.join does not deadlock. @@ -4538,7 +4663,7 @@ def test_dollars(self): f = logging.Formatter('${asctime}--', style='$') self.assertTrue(f.usesTime()) - # TODO: RustPython + # TODO: RUSTPYTHON; ValueError: Unexpected error parsing format string @unittest.expectedFailure def test_format_validate(self): # Check correct formatting @@ -4713,7 +4838,7 @@ def test_defaults_parameter(self): def test_invalid_style(self): self.assertRaises(ValueError, logging.Formatter, None, None, 'x') - # TODO: RustPython + # TODO: RUSTPYTHON; AttributeError: 'struct_time' object has no attribute 'tm_gmtoff' @unittest.expectedFailure def test_time(self): r = self.get_record() @@ -4729,7 +4854,7 @@ def test_time(self): f.format(r) self.assertEqual(r.asctime, '1993-04-21 08:03:00,123') - # TODO: RustPython + # TODO: RUSTPYTHON; AttributeError: 'struct_time' object has no attribute 'tm_gmtoff' @unittest.expectedFailure def test_default_msec_format_none(self): class NoMsecFormatter(logging.Formatter): @@ -4751,6 +4876,77 @@ def test_issue_89047(self): s = f.format(r) self.assertNotIn('.1000', s) + def test_msecs_has_no_floating_point_precision_loss(self): + # See issue gh-102402 + tests = ( + # time_ns is approx. 
2023-03-04 04:25:20 UTC + # (time_ns, expected_msecs_value) + (1_677_902_297_100_000_000, 100.0), # exactly 100ms + (1_677_903_920_999_998_503, 999.0), # check truncating doesn't round + (1_677_903_920_000_998_503, 0.0), # check truncating doesn't round + (1_677_903_920_999_999_900, 0.0), # check rounding up + ) + for ns, want in tests: + with patch('time.time_ns') as patched_ns: + patched_ns.return_value = ns + record = logging.makeLogRecord({'msg': 'test'}) + with self.subTest(ns): + self.assertEqual(record.msecs, want) + self.assertEqual(record.created, ns / 1e9) + self.assertAlmostEqual(record.created - int(record.created), + record.msecs / 1e3, + delta=1e-3) + + def test_relativeCreated_has_higher_precision(self): + # See issue gh-102402. + # Run the code in the subprocess, because the time module should + # be patched before the first import of the logging package. + # Temporary unloading and re-importing the logging package has + # side effects (including registering the atexit callback and + # references leak). + start_ns = 1_677_903_920_000_998_503 # approx. 2023-03-04 04:25:20 UTC + offsets_ns = (200, 500, 12_354, 99_999, 1_677_903_456_999_123_456) + code = textwrap.dedent(f""" + start_ns = {start_ns!r} + offsets_ns = {offsets_ns!r} + start_monotonic_ns = start_ns - 1 + + import time + # Only time.time_ns needs to be patched for the current + # implementation, but patch also other functions to make + # the test less implementation depending. 
+ old_time_ns = time.time_ns + old_time = time.time + old_monotonic_ns = time.monotonic_ns + old_monotonic = time.monotonic + time_ns_result = start_ns + time.time_ns = lambda: time_ns_result + time.time = lambda: time.time_ns()/1e9 + time.monotonic_ns = lambda: time_ns_result - start_monotonic_ns + time.monotonic = lambda: time.monotonic_ns()/1e9 + try: + import logging + + for offset_ns in offsets_ns: + # mock for log record creation + time_ns_result = start_ns + offset_ns + record = logging.makeLogRecord({{'msg': 'test'}}) + print(record.created, record.relativeCreated) + finally: + time.time_ns = old_time_ns + time.time = old_time + time.monotonic_ns = old_monotonic_ns + time.monotonic = old_monotonic + """) + rc, out, err = assert_python_ok("-c", code) + out = out.decode() + for offset_ns, line in zip(offsets_ns, out.splitlines(), strict=True): + with self.subTest(offset_ns=offset_ns): + created, relativeCreated = map(float, line.split()) + self.assertAlmostEqual(created, (start_ns + offset_ns) / 1e9, places=6) + # After PR gh-102412, precision (places) increases from 3 to 7 + self.assertAlmostEqual(relativeCreated, offset_ns / 1e6, places=7) + class TestBufferingFormatter(logging.BufferingFormatter): def formatHeader(self, records): @@ -4795,9 +4991,9 @@ def test_formatting(self): self.assertTrue(r.exc_text.endswith('\nRuntimeError: ' 'deliberate mistake')) self.assertTrue(r.stack_info.startswith('Stack (most recent ' - 'call last):\n')) + 'call last):\n')) self.assertTrue(r.stack_info.endswith('logging.exception(\'failed\', ' - 'stack_info=True)')) + 'stack_info=True)')) class LastResortTest(BaseTest): @@ -5061,7 +5257,7 @@ def __init__(self, name='MyLogger', level=logging.NOTSET): h.close() logging.setLoggerClass(logging.Logger) - # TODO: RustPython + # TODO: RUSTPYTHON @unittest.expectedFailure def test_logging_at_shutdown(self): # bpo-20037: Doing text I/O late at interpreter shutdown must not crash @@ -5082,7 +5278,7 @@ def __del__(self): 
self.assertIn("exception in __del__", err) self.assertIn("ValueError: some error", err) - # TODO: RustPython + # TODO: RUSTPYTHON @unittest.expectedFailure def test_logging_at_shutdown_open(self): # bpo-26789: FileHandler keeps a reference to the builtin open() @@ -5175,7 +5371,7 @@ def _extract_logrecord_process_name(key, logMultiprocessing, conn=None): results = {'processName' : name, 'r1.processName': r1.processName, 'r2.processName': r2.processName, - } + } finally: logging.logMultiprocessing = prev_logMultiprocessing if conn: @@ -5183,6 +5379,7 @@ def _extract_logrecord_process_name(key, logMultiprocessing, conn=None): else: return results + @skip_if_tsan_fork def test_multiprocessing(self): support.skip_if_broken_multiprocessing_synchronize() multiprocessing_imported = 'multiprocessing' in sys.modules @@ -5333,7 +5530,7 @@ def test_strformatstyle(self): logging.error("Log an error") sys.stdout.seek(0) self.assertEqual(output.getvalue().strip(), - "ERROR:root:Log an error") + "ERROR:root:Log an error") def test_stringtemplatestyle(self): with support.captured_stdout() as output: @@ -5341,7 +5538,7 @@ def test_stringtemplatestyle(self): logging.error("Log an error") sys.stdout.seek(0) self.assertEqual(output.getvalue().strip(), - "ERROR:root:Log an error") + "ERROR:root:Log an error") def test_filename(self): @@ -5418,11 +5615,11 @@ def test_incompatible(self): handlers = [logging.StreamHandler()] stream = sys.stderr assertRaises(ValueError, logging.basicConfig, filename='test.log', - stream=stream) + stream=stream) assertRaises(ValueError, logging.basicConfig, filename='test.log', - handlers=handlers) + handlers=handlers) assertRaises(ValueError, logging.basicConfig, stream=stream, - handlers=handlers) + handlers=handlers) # Issue 23207: test for invalid kwargs assertRaises(ValueError, logging.basicConfig, loglevel=logging.INFO) # Should pop both filename and filemode even if filename is None @@ -5803,6 +6000,30 @@ def test_extra_not_merged_by_default(self): 
record = self.recording.records[0] self.assertFalse(hasattr(record, 'foo')) + def test_extra_merged(self): + self.adapter = logging.LoggerAdapter(logger=self.logger, + extra={'foo': '1'}, + merge_extra=True) + + self.adapter.critical('foo and bar should be here', extra={'bar': '2'}) + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertTrue(hasattr(record, 'foo')) + self.assertTrue(hasattr(record, 'bar')) + self.assertEqual(record.foo, '1') + self.assertEqual(record.bar, '2') + + def test_extra_merged_log_call_has_precedence(self): + self.adapter = logging.LoggerAdapter(logger=self.logger, + extra={'foo': '1'}, + merge_extra=True) + + self.adapter.critical('foo shall be min', extra={'foo': '2'}) + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertTrue(hasattr(record, 'foo')) + self.assertEqual(record.foo, '2') + class PrefixAdapter(logging.LoggerAdapter): prefix = 'Adapter' @@ -6112,14 +6333,14 @@ def test_should_not_rollover(self): # If maxBytes is zero rollover never occurs rh = logging.handlers.RotatingFileHandler( - self.fn, encoding="utf-8", maxBytes=0) + self.fn, encoding="utf-8", maxBytes=0) self.assertFalse(rh.shouldRollover(None)) rh.close() with open(self.fn, 'wb') as f: f.write(b'\n') rh = logging.handlers.RotatingFileHandler( - self.fn, encoding="utf-8", maxBytes=0) + self.fn, encoding="utf-8", maxBytes=0) self.assertFalse(rh.shouldRollover(None)) rh.close() @@ -6129,7 +6350,7 @@ def test_should_not_rollover_non_file(self): # We set maxBytes to 1 so that rollover would normally happen, except # for the check for regular files rh = logging.handlers.RotatingFileHandler( - os.devnull, encoding="utf-8", maxBytes=1) + os.devnull, encoding="utf-8", maxBytes=1) self.assertFalse(rh.shouldRollover(self.next_rec())) rh.close() @@ -6260,12 +6481,12 @@ def rotator(source, dest): class TimedRotatingFileHandlerTest(BaseFileTest): # TODO: RUSTPYTHON - @unittest.skip("OS 
dependent bug") + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.") def test_should_not_rollover(self): # See bpo-45401. Should only ever rollover regular files fh = logging.handlers.TimedRotatingFileHandler( - os.devnull, 'S', encoding="utf-8", backupCount=1) + os.devnull, 'S', encoding="utf-8", backupCount=1) time.sleep(1.1) # a little over a second ... r = logging.makeLogRecord({'msg': 'testing - device file'}) self.assertFalse(fh.shouldRollover(r)) @@ -6274,7 +6495,7 @@ def test_should_not_rollover(self): # other test methods added below def test_rollover(self): fh = logging.handlers.TimedRotatingFileHandler( - self.fn, 'S', encoding="utf-8", backupCount=1) + self.fn, 'S', encoding="utf-8", backupCount=1) fmt = logging.Formatter('%(asctime)s %(message)s') fh.setFormatter(fmt) r1 = logging.makeLogRecord({'msg': 'testing - initial'}) @@ -6314,8 +6535,8 @@ def test_rollover(self): print(tf.read()) self.assertTrue(found, msg=msg) - # TODO: RustPython - @unittest.skip("OS dependent bug") + # TODO: RUSTPYTHON + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_rollover_at_midnight(self, weekly=False): os_helper.unlink(self.fn) now = datetime.datetime.now() @@ -6359,8 +6580,8 @@ def test_rollover_at_midnight(self, weekly=False): for i, line in enumerate(f): self.assertIn(f'testing1 {i}', line) - # TODO: RustPython - @unittest.skip("OS dependent bug") + # TODO: RUSTPYTHON + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_rollover_at_weekday(self): self.test_rollover_at_midnight(weekly=True) @@ -6944,7 +7165,7 @@ def secs(**kw): ('MIDNIGHT', 60 * 60 * 24), # current time (epoch start) is a Thursday, W0 means Monday ('W0', secs(days=4, hours=24)), - ): + ): for interval in 1, 3: def test_compute_rollover(self, when=when, interval=interval, exp=exp): rh = logging.handlers.TimedRotatingFileHandler( @@ -6970,8 +7191,8 @@ def test_compute_rollover(self, when=when, 
interval=interval, exp=exp): currentSecond = t[5] # r is the number of seconds left between now and midnight r = logging.handlers._MIDNIGHT - ((currentHour * 60 + - currentMinute) * 60 + - currentSecond) + currentMinute) * 60 + + currentSecond) result = currentTime + r print('t: %s (%s)' % (t, rh.utc), file=sys.stderr) print('currentHour: %s' % currentHour, file=sys.stderr) diff --git a/Lib/test/test_multiprocessing_fork/__init__.py b/Lib/test/test_multiprocessing_fork/__init__.py index aa1fff50b2..b35e82879d 100644 --- a/Lib/test/test_multiprocessing_fork/__init__.py +++ b/Lib/test/test_multiprocessing_fork/__init__.py @@ -12,5 +12,8 @@ if sys.platform == 'darwin': raise unittest.SkipTest("test may crash on macOS (bpo-33725)") +if support.check_sanitizer(thread=True): + raise unittest.SkipTest("TSAN doesn't support threads after fork") + def load_tests(*args): return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_property.py b/Lib/test/test_property.py index 8411e903b1..340f79b843 100644 --- a/Lib/test/test_property.py +++ b/Lib/test/test_property.py @@ -242,8 +242,6 @@ class PropertySubSlots(property): class PropertySubclassTests(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_slots_docstring_copy_exception(self): try: class Foo(object): diff --git a/Lib/test/test_selectors.py b/Lib/test/test_selectors.py index af730a8dec..59c1d26a7c 100644 --- a/Lib/test/test_selectors.py +++ b/Lib/test/test_selectors.py @@ -6,8 +6,7 @@ import socket import sys from test import support -from test.support import os_helper -from test.support import socket_helper +from test.support import is_apple, os_helper, socket_helper from time import sleep import unittest import unittest.mock @@ -132,6 +131,7 @@ def test_unregister_after_fd_close_and_reuse(self): s.unregister(r) s.unregister(w) + # TODO: RUSTPYTHON @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_unregister_after_socket_close(self): s = 
self.SELECTOR() @@ -224,6 +224,8 @@ def test_close(self): self.assertRaises(RuntimeError, s.get_key, wr) self.assertRaises(KeyError, mapping.__getitem__, rd) self.assertRaises(KeyError, mapping.__getitem__, wr) + self.assertEqual(mapping.get(rd), None) + self.assertEqual(mapping.get(wr), None) def test_get_key(self): s = self.SELECTOR() @@ -242,13 +244,17 @@ def test_get_map(self): self.addCleanup(s.close) rd, wr = self.make_socketpair() + sentinel = object() keys = s.get_map() self.assertFalse(keys) self.assertEqual(len(keys), 0) self.assertEqual(list(keys), []) + self.assertEqual(keys.get(rd), None) + self.assertEqual(keys.get(rd, sentinel), sentinel) key = s.register(rd, selectors.EVENT_READ, "data") self.assertIn(rd, keys) + self.assertEqual(key, keys.get(rd)) self.assertEqual(key, keys[rd]) self.assertEqual(len(keys), 1) self.assertEqual(list(keys), [rd.fileno()]) @@ -521,7 +527,7 @@ def test_above_fd_setsize(self): try: fds = s.select() except OSError as e: - if e.errno == errno.EINVAL and sys.platform == 'darwin': + if e.errno == errno.EINVAL and is_apple: # unexplainable errors on macOS don't need to fail the test self.skipTest("Invalid argument error calling poll()") raise diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 56cd8f8a68..3ef18111bc 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -769,8 +769,6 @@ def test_execute_illegal_sql(self): with self.assertRaises(sqlite.OperationalError): self.cu.execute("select asdf") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_execute_multiple_statements(self): msg = "You can only execute one statement at a time" dataset = ( @@ -793,8 +791,6 @@ def test_execute_multiple_statements(self): with self.assertRaisesRegex(sqlite.ProgrammingError, msg): self.cu.execute(query) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_execute_with_appended_comments(self): dataset = ( "select 1; -- foo bar", @@ -963,8 +959,6 @@ 
def test_rowcount_update_returning(self): self.assertEqual(self.cu.fetchone()[0], 1) self.assertEqual(self.cu.rowcount, 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_rowcount_prefixed_with_comment(self): # gh-79579: rowcount is updated even if query is prefixed with comments self.cu.execute(""" diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py index b6e9c24a80..ec4cca4528 100644 --- a/Lib/test/test_stat.py +++ b/Lib/test/test_stat.py @@ -2,8 +2,7 @@ import os import socket import sys -from test.support import os_helper -from test.support import socket_helper +from test.support import is_apple, os_helper, socket_helper from test.support.import_helper import import_fresh_module from test.support.os_helper import TESTFN @@ -15,8 +14,10 @@ class TestFilemode: statmod = None file_flags = {'SF_APPEND', 'SF_ARCHIVED', 'SF_IMMUTABLE', 'SF_NOUNLINK', - 'SF_SNAPSHOT', 'UF_APPEND', 'UF_COMPRESSED', 'UF_HIDDEN', - 'UF_IMMUTABLE', 'UF_NODUMP', 'UF_NOUNLINK', 'UF_OPAQUE'} + 'SF_SNAPSHOT', 'SF_SETTABLE', 'SF_RESTRICTED', 'SF_FIRMLINK', + 'SF_DATALESS', 'UF_APPEND', 'UF_COMPRESSED', 'UF_HIDDEN', + 'UF_IMMUTABLE', 'UF_NODUMP', 'UF_NOUNLINK', 'UF_OPAQUE', + 'UF_SETTABLE', 'UF_TRACKED', 'UF_DATAVAULT'} formats = {'S_IFBLK', 'S_IFCHR', 'S_IFDIR', 'S_IFIFO', 'S_IFLNK', 'S_IFREG', 'S_IFSOCK', 'S_IFDOOR', 'S_IFPORT', 'S_IFWHT'} @@ -113,6 +114,7 @@ def assertS_IS(self, name, mode): else: self.assertFalse(func(mode)) + @os_helper.skip_unless_working_chmod def test_mode(self): with open(TESTFN, 'w'): pass @@ -121,8 +123,11 @@ def test_mode(self): st_mode, modestr = self.get_mode() self.assertEqual(modestr, '-rwx------') self.assertS_IS("REG", st_mode) - self.assertEqual(self.statmod.S_IMODE(st_mode), + imode = self.statmod.S_IMODE(st_mode) + self.assertEqual(imode, self.statmod.S_IRWXU) + self.assertEqual(self.statmod.filemode(imode), + '?rwx------') os.chmod(TESTFN, 0o070) st_mode, modestr = self.get_mode() @@ -144,13 +149,21 @@ def test_mode(self): 
self.assertEqual(modestr, '-r--r--r--') self.assertEqual(self.statmod.S_IMODE(st_mode), 0o444) else: + os.chmod(TESTFN, 0o500) + st_mode, modestr = self.get_mode() + self.assertEqual(modestr[:3], '-r-') + self.assertS_IS("REG", st_mode) + self.assertEqual(self.statmod.S_IMODE(st_mode), 0o444) + os.chmod(TESTFN, 0o700) st_mode, modestr = self.get_mode() self.assertEqual(modestr[:3], '-rw') self.assertS_IS("REG", st_mode) self.assertEqual(self.statmod.S_IFMT(st_mode), self.statmod.S_IFREG) + self.assertEqual(self.statmod.S_IMODE(st_mode), 0o666) + @os_helper.skip_unless_working_chmod def test_directory(self): os.mkdir(TESTFN) os.chmod(TESTFN, 0o700) @@ -161,7 +174,7 @@ def test_directory(self): else: self.assertEqual(modestr[0], 'd') - @unittest.skipUnless(hasattr(os, 'symlink'), 'os.symlink not available') + @os_helper.skip_unless_symlink def test_link(self): try: os.symlink(os.getcwd(), TESTFN) @@ -227,6 +240,18 @@ def test_module_attributes(self): self.assertTrue(callable(func)) self.assertEqual(func(0), 0) + def test_flags_consistent(self): + self.assertFalse(self.statmod.UF_SETTABLE & self.statmod.SF_SETTABLE) + + for flag in self.file_flags: + if flag.startswith("UF"): + self.assertTrue(getattr(self.statmod, flag) & self.statmod.UF_SETTABLE, f"{flag} not in UF_SETTABLE") + elif is_apple and self.statmod is c_stat and flag == 'SF_DATALESS': + self.assertTrue(self.statmod.SF_DATALESS & self.statmod.SF_SYNTHETIC, "SF_DATALESS not in SF_SYNTHETIC") + self.assertFalse(self.statmod.SF_DATALESS & self.statmod.SF_SETTABLE, "SF_DATALESS in SF_SETTABLE") + else: + self.assertTrue(getattr(self.statmod, flag) & self.statmod.SF_SETTABLE, f"{flag} notin SF_SETTABLE") + @unittest.skipUnless(sys.platform == "win32", "FILE_ATTRIBUTE_* constants are Win32 specific") def test_file_attribute_constants(self): @@ -235,7 +260,68 @@ def test_file_attribute_constants(self): modvalue = getattr(self.statmod, key) self.assertEqual(value, modvalue, key) + @unittest.skipUnless(sys.platform 
== "darwin", "macOS system check") + def test_macosx_attribute_values(self): + self.assertEqual(self.statmod.UF_SETTABLE, 0x0000ffff) + self.assertEqual(self.statmod.UF_NODUMP, 0x00000001) + self.assertEqual(self.statmod.UF_IMMUTABLE, 0x00000002) + self.assertEqual(self.statmod.UF_APPEND, 0x00000004) + self.assertEqual(self.statmod.UF_OPAQUE, 0x00000008) + self.assertEqual(self.statmod.UF_COMPRESSED, 0x00000020) + self.assertEqual(self.statmod.UF_TRACKED, 0x00000040) + self.assertEqual(self.statmod.UF_DATAVAULT, 0x00000080) + self.assertEqual(self.statmod.UF_HIDDEN, 0x00008000) + + if self.statmod is c_stat: + self.assertEqual(self.statmod.SF_SUPPORTED, 0x009f0000) + self.assertEqual(self.statmod.SF_SETTABLE, 0x3fff0000) + self.assertEqual(self.statmod.SF_SYNTHETIC, 0xc0000000) + else: + self.assertEqual(self.statmod.SF_SETTABLE, 0xffff0000) + self.assertEqual(self.statmod.SF_ARCHIVED, 0x00010000) + self.assertEqual(self.statmod.SF_IMMUTABLE, 0x00020000) + self.assertEqual(self.statmod.SF_APPEND, 0x00040000) + self.assertEqual(self.statmod.SF_RESTRICTED, 0x00080000) + self.assertEqual(self.statmod.SF_NOUNLINK, 0x00100000) + self.assertEqual(self.statmod.SF_FIRMLINK, 0x00800000) + self.assertEqual(self.statmod.SF_DATALESS, 0x40000000) + + self.assertFalse(isinstance(self.statmod.S_IFMT, int)) + self.assertEqual(self.statmod.S_IFIFO, 0o010000) + self.assertEqual(self.statmod.S_IFCHR, 0o020000) + self.assertEqual(self.statmod.S_IFDIR, 0o040000) + self.assertEqual(self.statmod.S_IFBLK, 0o060000) + self.assertEqual(self.statmod.S_IFREG, 0o100000) + self.assertEqual(self.statmod.S_IFLNK, 0o120000) + self.assertEqual(self.statmod.S_IFSOCK, 0o140000) + + if self.statmod is c_stat: + self.assertEqual(self.statmod.S_IFWHT, 0o160000) + + self.assertEqual(self.statmod.S_IRWXU, 0o000700) + self.assertEqual(self.statmod.S_IRUSR, 0o000400) + self.assertEqual(self.statmod.S_IWUSR, 0o000200) + self.assertEqual(self.statmod.S_IXUSR, 0o000100) + self.assertEqual(self.statmod.S_IRWXG, 
0o000070) + self.assertEqual(self.statmod.S_IRGRP, 0o000040) + self.assertEqual(self.statmod.S_IWGRP, 0o000020) + self.assertEqual(self.statmod.S_IXGRP, 0o000010) + self.assertEqual(self.statmod.S_IRWXO, 0o000007) + self.assertEqual(self.statmod.S_IROTH, 0o000004) + self.assertEqual(self.statmod.S_IWOTH, 0o000002) + self.assertEqual(self.statmod.S_IXOTH, 0o000001) + self.assertEqual(self.statmod.S_ISUID, 0o004000) + self.assertEqual(self.statmod.S_ISGID, 0o002000) + self.assertEqual(self.statmod.S_ISVTX, 0o001000) + + self.assertFalse(hasattr(self.statmod, "S_ISTXT")) + self.assertEqual(self.statmod.S_IREAD, self.statmod.S_IRUSR) + self.assertEqual(self.statmod.S_IWRITE, self.statmod.S_IWUSR) + self.assertEqual(self.statmod.S_IEXEC, self.statmod.S_IXUSR) + + +@unittest.skipIf(c_stat is None, 'need _stat extension') class TestFilemodeCStat(TestFilemode, unittest.TestCase): statmod = c_stat diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index 242c076f9b..e7cd5962cf 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -1132,8 +1132,6 @@ class MyRef(weakref.ref): self.assertIn(r1, refs) self.assertIn(r2, refs) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_subclass_refs_with_slots(self): class MyRef(weakref.ref): __slots__ = "slot1", "slot2" diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py index 6c4cb05a5f..5e8cacc09d 100644 --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -44,7 +44,7 @@ def setUp(self): def test_methods(self): weaksetmethods = dir(WeakSet) for method in dir(set): - if method == 'test_c_api' or method.startswith('_'): + if method.startswith('_'): continue self.assertIn(method, weaksetmethods, "WeakSet missing method " + method) @@ -458,8 +458,6 @@ def test_abc(self): self.assertIsInstance(self.s, Set) self.assertIsInstance(self.s, MutableSet) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_copying(self): for cls in WeakSet, WeakSetWithSlots: s = cls(self.items) 
diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 62c8508cbf..59c7a28932 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -351,6 +351,7 @@ impl<'src> Compiler<'src> { static_attributes: None, in_inlined_comp: false, fblock: Vec::with_capacity(MAXBLOCKS), + symbol_table_index: 0, // Module is always the first symbol table }; Compiler { code_stack: vec![module_code], @@ -384,15 +385,76 @@ impl Compiler<'_> { } } + /// Get the SymbolTable for the current scope. + fn current_symbol_table(&self) -> &SymbolTable { + if self.symbol_table_stack.is_empty() { + panic!("symbol_table_stack is empty! This is a compiler bug."); + } + let index = self.symbol_table_stack.len() - 1; + &self.symbol_table_stack[index] + } + + /// Get the index of a free variable. + fn get_free_var_index(&mut self, name: &str) -> CompileResult { + let info = self.code_stack.last_mut().unwrap(); + let idx = info + .metadata + .freevars + .get_index_of(name) + .unwrap_or_else(|| info.metadata.freevars.insert_full(name.to_owned()).0); + Ok((idx + info.metadata.cellvars.len()).to_u32()) + } + + /// Get the index of a cell variable. + fn get_cell_var_index(&mut self, name: &str) -> CompileResult { + let info = self.code_stack.last_mut().unwrap(); + let idx = info + .metadata + .cellvars + .get_index_of(name) + .unwrap_or_else(|| info.metadata.cellvars.insert_full(name.to_owned()).0); + Ok(idx.to_u32()) + } + + /// Get the index of a local variable. + fn get_local_var_index(&mut self, name: &str) -> CompileResult { + let info = self.code_stack.last_mut().unwrap(); + let idx = info + .metadata + .varnames + .get_index_of(name) + .unwrap_or_else(|| info.metadata.varnames.insert_full(name.to_owned()).0); + Ok(idx.to_u32()) + } + + /// Get the index of a global name. 
+ fn get_global_name_index(&mut self, name: &str) -> u32 { + let info = self.code_stack.last_mut().unwrap(); + let idx = info + .metadata + .names + .get_index_of(name) + .unwrap_or_else(|| info.metadata.names.insert_full(name.to_owned()).0); + idx.to_u32() + } + /// Push the next symbol table on to the stack fn push_symbol_table(&mut self) -> &SymbolTable { // Look up the next table contained in the scope of the current table - let table = self + let current_table = self .symbol_table_stack .last_mut() - .expect("no next symbol table") - .sub_tables - .remove(0); + .expect("no current symbol table"); + + if current_table.sub_tables.is_empty() { + panic!( + "push_symbol_table: no sub_tables available in {} (type: {:?})", + current_table.name, current_table.typ + ); + } + + let table = current_table.sub_tables.remove(0); + // Push the next table onto the stack let last_idx = self.symbol_table_stack.len(); self.symbol_table_stack.push(table); @@ -537,6 +599,7 @@ impl Compiler<'_> { }, in_inlined_comp: false, fblock: Vec::with_capacity(MAXBLOCKS), + symbol_table_index: key, }; // Push the old compiler unit on the stack (like PyCapsule) @@ -615,10 +678,17 @@ impl Compiler<'_> { // compiler_exit_scope fn exit_scope(&mut self) -> CodeObject { - let table = self.pop_symbol_table(); - assert!(table.sub_tables.is_empty()); + let _table = self.pop_symbol_table(); + + // Various scopes can have sub_tables: + // - TypeParams scope can have sub_tables (the function body's symbol table) + // - Module scope can have sub_tables (for TypeAlias scopes, nested functions, classes) + // - Function scope can have sub_tables (for nested functions, classes) + // - Class scope can have sub_tables (for nested classes, methods) + let pop = self.code_stack.pop(); let stack_top = compiler_unwrap_option(self, pop); + // No parent scope stack to maintain unwrap_internal(self, stack_top.finalize_code(self.opts.optimize)) } @@ -679,7 +749,8 @@ impl Compiler<'_> { .to_u32() } - /// Set the 
qualified name for the current code object, based on CPython's compiler_set_qualname + /// Set the qualified name for the current code object + // = compiler_set_qualname fn set_qualname(&mut self) -> String { let qualname = self.make_qualname(); self.current_code_info().metadata.qualname = Some(qualname.clone()); @@ -700,13 +771,14 @@ impl Compiler<'_> { let mut parent_idx = stack_size - 2; let mut parent = &self.code_stack[parent_idx]; - // If parent is a type parameter scope, look at grandparent + // If parent is TypeParams scope, look at grandparent + // Check if parent is a type params scope by name pattern if parent.metadata.name.starts_with(" { Err(self.error(CodegenErrorType::SyntaxError(format!("{msg} {name}")))) } + // = compiler_nameop fn compile_name(&mut self, name: &str, usage: NameUsage) -> CompileResult<()> { - let name = self.mangle(name); - - self.check_forbidden_name(&name, usage)?; - - let symbol_table = self.symbol_table_stack.last().unwrap(); - let symbol = unwrap_internal( - self, - symbol_table - .lookup(name.as_ref()) - .ok_or_else(|| InternalError::MissingSymbol(name.to_string())), - ); - let info = self.code_stack.last_mut().unwrap(); - let mut cache = &mut info.metadata.names; - enum NameOpType { + enum NameOp { Fast, Global, Deref, - Local, - } - let op_typ = match symbol.scope { - SymbolScope::Local if self.ctx.in_func() => { - cache = &mut info.metadata.varnames; - NameOpType::Fast - } - SymbolScope::GlobalExplicit => NameOpType::Global, - SymbolScope::GlobalImplicit | SymbolScope::Unknown if self.ctx.in_func() => { - NameOpType::Global - } - SymbolScope::GlobalImplicit | SymbolScope::Unknown => NameOpType::Local, - SymbolScope::Local => NameOpType::Local, - SymbolScope::Free => { - cache = &mut info.metadata.freevars; - NameOpType::Deref - } - SymbolScope::Cell => { - cache = &mut info.metadata.cellvars; - NameOpType::Deref - } // TODO: is this right? 
- // SymbolScope::Unknown => NameOpType::Global, - }; + Name, + } + + let name = self.mangle(name); + self.check_forbidden_name(&name, usage)?; + // Special handling for __debug__ if NameUsage::Load == usage && name == "__debug__" { self.emit_load_const(ConstantData::Boolean { value: self.opts.optimize == 0, @@ -960,38 +1003,111 @@ impl Compiler<'_> { return Ok(()); } - let mut idx = cache - .get_index_of(name.as_ref()) - .unwrap_or_else(|| cache.insert_full(name.into_owned()).0); - if let SymbolScope::Free = symbol.scope { - idx += info.metadata.cellvars.len(); - } - let op = match op_typ { - NameOpType::Fast => match usage { - NameUsage::Load => Instruction::LoadFast, - NameUsage::Store => Instruction::StoreFast, - NameUsage::Delete => Instruction::DeleteFast, - }, - NameOpType::Global => match usage { - NameUsage::Load => Instruction::LoadGlobal, - NameUsage::Store => Instruction::StoreGlobal, - NameUsage::Delete => Instruction::DeleteGlobal, - }, - NameOpType::Deref => match usage { - NameUsage::Load if !self.ctx.in_func() && self.ctx.in_class => { - Instruction::LoadClassDeref + // Determine the operation type based on symbol scope + let is_function_like = self.ctx.in_func(); + + // Look up the symbol, handling TypeParams scope specially + let (symbol_scope, _is_typeparams) = { + let current_table = self.current_symbol_table(); + let is_typeparams = current_table.typ == CompilerScope::TypeParams; + + // First try to find in current table + let symbol = current_table.lookup(name.as_ref()); + + // If not found and we're in TypeParams scope, try parent scope + let symbol = if symbol.is_none() && is_typeparams { + if self.symbol_table_stack.len() > 1 { + let parent_idx = self.symbol_table_stack.len() - 2; + self.symbol_table_stack[parent_idx].lookup(name.as_ref()) + } else { + None } - NameUsage::Load => Instruction::LoadDeref, - NameUsage::Store => Instruction::StoreDeref, - NameUsage::Delete => Instruction::DeleteDeref, - }, - NameOpType::Local => match usage { 
- NameUsage::Load => Instruction::LoadNameAny, - NameUsage::Store => Instruction::StoreLocal, - NameUsage::Delete => Instruction::DeleteLocal, - }, + } else { + symbol + }; + + (symbol.map(|s| s.scope), is_typeparams) + }; + + let actual_scope = symbol_scope.ok_or_else(|| { + self.error(CodegenErrorType::SyntaxError(format!( + "The symbol '{name}' must be present in the symbol table" + ))) + })?; + + // Determine operation type based on scope + let op_type = match actual_scope { + SymbolScope::Free => NameOp::Deref, + SymbolScope::Cell => NameOp::Deref, + SymbolScope::Local => { + if is_function_like { + NameOp::Fast + } else { + NameOp::Name + } + } + SymbolScope::GlobalImplicit => { + if is_function_like { + NameOp::Global + } else { + NameOp::Name + } + } + SymbolScope::GlobalExplicit => NameOp::Global, + SymbolScope::Unknown => NameOp::Name, }; - self.emit_arg(idx.to_u32(), op); + + // Generate appropriate instructions based on operation type + match op_type { + NameOp::Deref => { + let idx = match actual_scope { + SymbolScope::Free => self.get_free_var_index(&name)?, + SymbolScope::Cell => self.get_cell_var_index(&name)?, + _ => unreachable!("Invalid scope for Deref operation"), + }; + + let op = match usage { + NameUsage::Load => { + // Special case for class scope + if self.ctx.in_class && !self.ctx.in_func() { + Instruction::LoadClassDeref + } else { + Instruction::LoadDeref + } + } + NameUsage::Store => Instruction::StoreDeref, + NameUsage::Delete => Instruction::DeleteDeref, + }; + self.emit_arg(idx, op); + } + NameOp::Fast => { + let idx = self.get_local_var_index(&name)?; + let op = match usage { + NameUsage::Load => Instruction::LoadFast, + NameUsage::Store => Instruction::StoreFast, + NameUsage::Delete => Instruction::DeleteFast, + }; + self.emit_arg(idx, op); + } + NameOp::Global => { + let idx = self.get_global_name_index(&name); + let op = match usage { + NameUsage::Load => Instruction::LoadGlobal, + NameUsage::Store => Instruction::StoreGlobal, + 
NameUsage::Delete => Instruction::DeleteGlobal, + }; + self.emit_arg(idx, op); + } + NameOp::Name => { + let idx = self.get_global_name_index(&name); + let op = match usage { + NameUsage::Load => Instruction::LoadNameAny, + NameUsage::Store => Instruction::StoreLocal, + NameUsage::Delete => Instruction::DeleteLocal, + }; + self.emit_arg(idx, op); + } + } Ok(()) } @@ -1414,6 +1530,7 @@ impl Compiler<'_> { }); if let Some(type_params) = type_params { + // For TypeAlias, we need to use push_symbol_table to properly handle the TypeAlias scope self.push_symbol_table(); // Compile type params and push to stack @@ -1424,9 +1541,10 @@ impl Compiler<'_> { self.compile_expression(value)?; // Stack: [name, type_params_tuple, value] + // Pop the TypeAlias scope self.pop_symbol_table(); } else { - // Push None for type_params (matching CPython) + // Push None for type_params self.emit_load_const(ConstantData::None); // Stack: [name, None] @@ -1735,7 +1853,7 @@ impl Compiler<'_> { // Delete the exception variable if it was bound if let Some(alias) = name { - // Set the variable to None before deleting (as CPython does) + // Set the variable to None before deleting self.emit_load_const(ConstantData::None); self.store_name(alias.as_str())?; self.compile_name(alias.as_str(), NameUsage::Delete)?; @@ -1803,43 +1921,36 @@ impl Compiler<'_> { is_forbidden_name(name) } - #[allow(clippy::too_many_arguments)] - fn compile_function_def( + /// Compile default arguments + // = compiler_default_arguments + fn compile_default_arguments( &mut self, - name: &str, parameters: &Parameters, - body: &[Stmt], - decorator_list: &[Decorator], - returns: Option<&Expr>, // TODO: use type hint somehow.. 
- is_async: bool,
- type_params: Option<&TypeParams>,
- ) -> CompileResult<()> {
- self.prepare_decorators(decorator_list)?;
- - // If there are type params, we need to push a special symbol table just for them - if type_params.is_some() { - self.push_symbol_table(); - } + ) -> CompileResult<bytecode::MakeFunctionFlags> { + let mut funcflags = bytecode::MakeFunctionFlags::empty(); - // Prepare defaults and kwdefaults before entering function + // Handle positional defaults let defaults: Vec<_> = std::iter::empty() .chain(&parameters.posonlyargs) .chain(&parameters.args) .filter_map(|x| x.default.as_deref()) .collect(); - let have_defaults = !defaults.is_empty(); - // Compile defaults before entering function scope - if have_defaults { - // Construct a tuple: - let size = defaults.len().to_u32(); - for element in &defaults { - self.compile_expression(element)?; + if !defaults.is_empty() { + // Compile defaults and build tuple + for default in &defaults { + self.compile_expression(default)?; } - emit!(self, Instruction::BuildTuple { size }); + emit!( + self, + Instruction::BuildTuple { + size: defaults.len().to_u32() + } + ); + funcflags |= bytecode::MakeFunctionFlags::DEFAULTS; } - // Prepare keyword-only defaults + // Handle keyword-only defaults let mut kw_with_defaults = vec![]; for kwonlyarg in &parameters.kwonlyargs { if let Some(default) = &kwonlyarg.default { @@ -1847,10 +1958,9 @@ impl Compiler<'_> { } } - let have_kwdefaults = !kw_with_defaults.is_empty(); - if have_kwdefaults { - let default_kw_count = kw_with_defaults.len(); - for (arg, default) in kw_with_defaults.iter() { + if !kw_with_defaults.is_empty() { + // Compile kwdefaults and build dict + for (arg, default) in &kw_with_defaults { self.emit_load_const(ConstantData::Str { value: arg.name.as_str().into(), }); @@ -1859,26 +1969,33 @@ impl Compiler<'_> { emit!( self, Instruction::BuildMap { - size: default_kw_count.to_u32(), + size: kw_with_defaults.len().to_u32(), } ); + funcflags |= bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS; } + 
Ok(funcflags) + } + + /// Compile function body and create function object + // = compiler_function_body + fn compile_function_body( + &mut self, + name: &str, + parameters: &Parameters, + body: &[Stmt], + is_async: bool, + funcflags: bytecode::MakeFunctionFlags, + ) -> CompileResult<()> { + // Always enter function scope self.enter_function(name, parameters)?; - let mut func_flags = bytecode::MakeFunctionFlags::empty(); - if have_defaults { - func_flags |= bytecode::MakeFunctionFlags::DEFAULTS; - } - if have_kwdefaults { - func_flags |= bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS; - } self.current_code_info() .flags .set(bytecode::CodeFlags::IS_COROUTINE, is_async); - // remember to restore self.ctx.in_loop to the original after the function is compiled + // Set up context let prev_ctx = self.ctx; - self.ctx = CompileContext { loop_data: None, in_class: prev_ctx.in_class, @@ -1889,51 +2006,67 @@ impl Compiler<'_> { }, }; - // Set qualname using the new method + // Set qualname self.set_qualname(); + // Handle docstring let (doc_str, body) = split_doc(body, &self.opts); - self.current_code_info() .metadata .consts .insert_full(ConstantData::None); + // Compile body statements self.compile_statements(body)?; - // Emit None at end: + // Emit None at end if needed match body.last() { - Some(Stmt::Return(_)) => { - // the last instruction is a ReturnValue already, we don't need to emit it - } + Some(Stmt::Return(_)) => {} _ => { self.emit_return_const(ConstantData::None); } } + // Exit scope and create function object let code = self.exit_scope(); self.ctx = prev_ctx; - // Prepare generic type parameters: - if let Some(type_params) = type_params { - self.compile_type_params(type_params)?; - func_flags |= bytecode::MakeFunctionFlags::TYPE_PARAMS; + // Create function object with closure + self.make_closure(code, funcflags)?; + + // Handle docstring if present + if let Some(doc) = doc_str { + emit!(self, Instruction::Duplicate); + self.emit_load_const(ConstantData::Str 
{
+ value: doc.to_string().into(),
+ });
+ emit!(self, Instruction::Rotate2);
+ let doc_attr = self.name("__doc__");
+ emit!(self, Instruction::StoreAttr { idx: doc_attr });
}
- // Prepare type annotations:
+ Ok(())
+ }
+
+ /// Compile function annotations
+ // = compiler_visit_annotations
+ fn visit_annotations(
+ &mut self,
+ parameters: &Parameters,
+ returns: Option<&Expr>,
+ ) -> CompileResult<usize> {
let mut num_annotations = 0;
- // Return annotation:
+ // Handle return annotation first
if let Some(annotation) = returns {
- // key: self.emit_load_const(ConstantData::Str { value: "return".into(), });
- // value: self.compile_annotation(annotation)?; num_annotations += 1; } + // Handle parameter annotations let parameters_iter = std::iter::empty() .chain(&parameters.posonlyargs) .chain(&parameters.args) @@ -1941,6 +2074,7 @@ .map(|x| &x.parameter) .chain(parameters.vararg.as_deref()) .chain(parameters.kwarg.as_deref()); + for param in parameters_iter { if let Some(annotation) = &param.annotation { self.emit_load_const(ConstantData::Str { @@ -1951,8 +2085,83 @@ } } + Ok(num_annotations) + } + + // = compiler_function + #[allow(clippy::too_many_arguments)] + fn compile_function_def( + &mut self, + name: &str, + parameters: &Parameters, + body: &[Stmt], + decorator_list: &[Decorator], + returns: Option<&Expr>, // TODO: use type hint somehow.. 
+ is_async: bool,
+ type_params: Option<&TypeParams>,
+ ) -> CompileResult<()> {
+ self.prepare_decorators(decorator_list)?;
+
+ // compile defaults and return funcflags
+ let funcflags = self.compile_default_arguments(parameters)?;
+
+ let is_generic = type_params.is_some();
+ let mut num_typeparam_args = 0;
+
+ if is_generic {
+ // Count args to pass to type params scope
+ if funcflags.contains(bytecode::MakeFunctionFlags::DEFAULTS) {
+ num_typeparam_args += 1;
+ }
+ if funcflags.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) {
+ num_typeparam_args += 1;
+ }
+
+ // SWAP if we have both
+ if num_typeparam_args == 2 {
+ emit!(self, Instruction::Swap { index: 2 });
+ }
+
+ // Enter type params scope
+ let type_params_name = format!("<generic parameters of {name}>");
+ self.push_output(
+ bytecode::CodeFlags::IS_OPTIMIZED | bytecode::CodeFlags::NEW_LOCALS,
+ 0,
+ num_typeparam_args as u32,
+ 0,
+ type_params_name,
+ );
+
+ // Add parameter names to varnames for the type params scope
+ // These will be passed as arguments when the closure is called
+ let current_info = self.current_code_info();
+ if funcflags.contains(bytecode::MakeFunctionFlags::DEFAULTS) {
+ current_info
+ .metadata
+ .varnames
+ .insert(".defaults".to_owned());
+ }
+ if funcflags.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) {
+ current_info
+ .metadata
+ .varnames
+ .insert(".kwdefaults".to_owned());
+ }
+
+ // Compile type parameters
+ self.compile_type_params(type_params.unwrap())?;
+
+ // Load defaults/kwdefaults with LOAD_FAST
+ for i in 0..num_typeparam_args {
+ emit!(self, Instruction::LoadFast(i as u32));
+ }
+ }
+
+ // Compile annotations
+ let mut annotations_flag = bytecode::MakeFunctionFlags::empty();
+ let num_annotations = self.visit_annotations(parameters, returns)?; if num_annotations > 0 { - func_flags |= bytecode::MakeFunctionFlags::ANNOTATIONS; + annotations_flag = bytecode::MakeFunctionFlags::ANNOTATIONS; emit!( self, Instruction::BuildMap { @@ -1961,27 +2170,63 @@ ); } 
- // Pop the special type params symbol table - if type_params.is_some() { - self.pop_symbol_table(); - } + // Compile function body + let final_funcflags = funcflags | annotations_flag; + self.compile_function_body(name, parameters, body, is_async, final_funcflags)?; - // Create function with closure - self.make_closure(code, func_flags)?; + // Handle type params if present + if is_generic { + // SWAP to get function on top + // Stack: [type_params_tuple, function] -> [function, type_params_tuple] + emit!(self, Instruction::Swap { index: 2 }); - if let Some(value) = doc_str { - emit!(self, Instruction::Duplicate); - self.emit_load_const(ConstantData::Str { - value: value.into(), - }); - emit!(self, Instruction::Rotate2); - let doc = self.name("__doc__"); - emit!(self, Instruction::StoreAttr { idx: doc }); + // Call INTRINSIC_SET_FUNCTION_TYPE_PARAMS + emit!( + self, + Instruction::CallIntrinsic2 { + func: bytecode::IntrinsicFunction2::SetFunctionTypeParams, + } + ); + + // Return the function object from type params scope + emit!(self, Instruction::ReturnValue); + + // Set argcount for type params scope + self.current_code_info().metadata.argcount = num_typeparam_args as u32; + + // Exit type params scope and create closure + let type_params_code = self.exit_scope(); + + // Make closure for type params code + self.make_closure(type_params_code, bytecode::MakeFunctionFlags::empty())?; + + // Call the closure + if num_typeparam_args > 0 { + emit!( + self, + Instruction::Swap { + index: (num_typeparam_args + 1) as u32 + } + ); + emit!( + self, + Instruction::CallFunctionPositional { + nargs: num_typeparam_args as u32 + } + ); + } else { + // No arguments, just call the closure + emit!(self, Instruction::CallFunctionPositional { nargs: 0 }); + } } + // Apply decorators self.apply_decorators(decorator_list); - self.store_name(name) + // Store the function + self.store_name(name)?; + + Ok(()) } /// Determines if a variable should be CELL or FREE type @@ -2188,7 +2433,7 
@@ impl Compiler<'_> { } /// Compile the class body into a code object - /// This is similar to CPython's compiler_class_body + // = compiler_class_body fn compile_class_body( &mut self, name: &str, @@ -2197,7 +2442,6 @@ firstlineno: u32, ) -> CompileResult<CodeObject> { // 1. Enter class scope - // Use enter_scope instead of push_output to match CPython let key = self.symbol_table_stack.len(); self.push_symbol_table(); self.enter_scope(name, CompilerScope::Class, key, firstlineno)?; @@ -5166,8 +5410,8 @@ impl EmitArg for ir::BlockIdx { /// Strips leading whitespace from a docstring. /// -/// The code has been ported from `_PyCompile_CleanDoc` in cpython. -/// `inspect.cleandoc` is also a good reference, but has a few incompatibilities. +/// `inspect.cleandoc` is a good reference, but has a few incompatibilities. +// = _PyCompile_CleanDoc fn clean_doc(doc: &str) -> String { let doc = expandtabs(doc, 8); // First pass: find minimum indentation of any non-blank lines diff --git a/compiler/codegen/src/ir.rs b/compiler/codegen/src/ir.rs index f2299892b3..ae4eef9cc9 100644 --- a/compiler/codegen/src/ir.rs +++ b/compiler/codegen/src/ir.rs @@ -101,6 +101,9 @@ pub struct CodeInfo { // Block stack for tracking nested control structures pub fblock: Vec<FBlockInfo>, + + // Reference to the symbol table for this scope + pub symbol_table_index: usize, } impl CodeInfo { pub fn finalize_code(mut self, optimize: u8) -> crate::InternalResult<CodeObject> { @@ -122,6 +125,7 @@ impl CodeInfo { static_attributes: _, in_inlined_comp: _, fblock: _, + symbol_table_index: _, } = self; let CodeUnitMetadata { @@ -318,7 +322,10 @@ impl CodeInfo { continue 'process_blocks; } } - stackdepth_push(&mut stack, &mut start_depths, block.next, depth); + // Only push next block if it's not NULL + if block.next != BlockIdx::NULL { + stackdepth_push(&mut stack, &mut start_depths, block.next, depth); + } } if DEBUG { eprintln!("DONE: {maxdepth}"); diff --git a/compiler/codegen/src/symboltable.rs 
b/compiler/codegen/src/symboltable.rs index e158514f87..8ab697f248 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -1012,6 +1012,28 @@ impl SymbolTableBuilder<'_> { context: ExpressionContext, ) -> SymbolTableResult { use ruff_python_ast::*; + + // Check for expressions not allowed in type parameters scope + if let Some(table) = self.tables.last() { + if table.typ == CompilerScope::TypeParams { + if let Some(keyword) = match expression { + Expr::Yield(_) | Expr::YieldFrom(_) => Some("yield"), + Expr::Await(_) => Some("await"), + Expr::Named(_) => Some("named"), + _ => None, + } { + return Err(SymbolTableError { + error: format!( + "{keyword} expression cannot be used within a type parameter" + ), + location: Some( + self.source_code.source_location(expression.range().start()), + ), + }); + } + } + } + match expression { Expr::BinOp(ExprBinOp { left, diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index ce70a5883d..ad07d3df72 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -3081,36 +3081,52 @@ mod _sqlite { } fn lstrip_sql(sql: &[u8]) -> Option<&[u8]> { - let mut pos = sql; - loop { - match pos.first()? { + let mut pos = 0; + + // This loop is borrowed from the SQLite source code. + while let Some(t_char) = sql.get(pos) { + match t_char { b' ' | b'\t' | b'\x0c' | b'\n' | b'\r' => { - pos = &pos[1..]; + // Skip whitespace. + pos += 1; } b'-' => { - if *pos.get(1)? == b'-' { - // line comments - pos = &pos[2..]; - while *pos.first()? != b'\n' { - pos = &pos[1..]; + // Skip line comments. + if sql.get(pos + 1) == Some(&b'-') { + pos += 2; + while let Some(&ch) = sql.get(pos) { + if ch == b'\n' { + break; + } + pos += 1; } + let _ = sql.get(pos)?; } else { - return Some(pos); + return Some(&sql[pos..]); } } b'/' => { - if *pos.get(1)? == b'*' { - // c style comments - pos = &pos[2..]; - while *pos.first()? != b'*' || *pos.get(1)? != b'/' { - pos = &pos[1..]; + // Skip C style comments. 
+ if sql.get(pos + 1) == Some(&b'*') { + pos += 2; + while let Some(&ch) = sql.get(pos) { + if ch == b'*' && sql.get(pos + 1) == Some(&b'/') { + break; + } + pos += 1; + } + let _ = sql.get(pos)?; + pos += 2; } else { - return Some(pos); + return Some(&sql[pos..]); } } - _ => return Some(pos), + _ => { + return Some(&sql[pos..]); + } } } + + None } } diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 94334d4a88..498330fa97 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -1038,9 +1038,6 @@ impl Constructor for PyType { }); } - // TODO: Flags is currently initialized with HAS_DICT. Should be - // updated when __slots__ are supported (toggling the flag off if - // a class has __slots__ defined). let heaptype_slots: Option<PyTupleTyped<PyStrRef>> = if let Some(x) = attributes.get(identifier!(vm, __slots__)) { let slots = if x.class().is(vm.ctx.types.str_type) { @@ -1072,7 +1069,12 @@ let heaptype_member_count = heaptype_slots.as_ref().map(|x| x.len()).unwrap_or(0); let member_count: usize = base_member_count + heaptype_member_count; - let flags = PyTypeFlags::heap_type_flags() | PyTypeFlags::HAS_DICT; + let mut flags = PyTypeFlags::heap_type_flags(); + // Only add HAS_DICT and MANAGED_DICT if __slots__ is not defined. 
+ if heaptype_slots.is_none() { + flags |= PyTypeFlags::HAS_DICT | PyTypeFlags::MANAGED_DICT; + } + let (slots, heaptype_ext) = { let slots = PyTypeSlots { flags, diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 59cfc0ac68..199596b9ed 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -697,8 +697,11 @@ impl ExecutingFrame<'_> { } bytecode::Instruction::Swap { index } => { let len = self.state.stack.len(); - let i = len - 1; - let j = len - 1 - index.get(arg) as usize; + let i = len - 1; // TOS index + let index_val = index.get(arg) as usize; + // SWAP(i) swaps TOS with element i positions down from TOS + // So the target index is len - index_val + let j = len - index_val; self.state.stack.swap(i, j); Ok(None) } diff --git a/vm/src/stdlib/stat.rs b/vm/src/stdlib/stat.rs index fdb88b37f0..bd5b2e2870 100644 --- a/vm/src/stdlib/stat.rs +++ b/vm/src/stdlib/stat.rs @@ -72,6 +72,11 @@ mod stat { // TODO: RUSTPYTHON Support BSD // https://man.freebsd.org/cgi/man.cgi?stat(2) + + #[cfg(target_os = "macos")] + #[pyattr] + pub const S_IFWHT: Mode = 0o160000; + #[cfg(not(target_os = "macos"))] #[pyattr] pub const S_IFWHT: Mode = 0; diff --git a/vm/src/stdlib/symtable.rs b/vm/src/stdlib/symtable.rs index e37a2f45bb..18c4836954 100644 --- a/vm/src/stdlib/symtable.rs +++ b/vm/src/stdlib/symtable.rs @@ -70,7 +70,10 @@ mod symtable { #[pymethod] fn is_optimized(&self) -> bool { - self.symtable.typ == CompilerScope::Function + matches!( + self.symtable.typ, + CompilerScope::Function | CompilerScope::AsyncFunction + ) } #[pymethod] diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 88e0f231cb..35a387ed78 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -122,6 +122,7 @@ bitflags! { #[derive(Copy, Clone, Debug, PartialEq)] #[non_exhaustive] pub struct PyTypeFlags: u64 { + const MANAGED_DICT = 1 << 4; const IMMUTABLETYPE = 1 << 8; const HEAPTYPE = 1 << 9; const BASETYPE = 1 << 10;