diff options
Diffstat (limited to 'awx/lib/site-packages')
181 files changed, 2850 insertions, 2346 deletions
diff --git a/awx/lib/site-packages/README b/awx/lib/site-packages/README index d9d46fd10f..aa3205363d 100644 --- a/awx/lib/site-packages/README +++ b/awx/lib/site-packages/README @@ -1,17 +1,17 @@ Local versions of third-party packages required by AWX. Package names and versions are listed below, along with notes on which files are included. -amqp-1.0.11 (amqp/*) +amqp-1.0.13 (amqp/*) anyjson-0.3.3 (anyjson/*) -billiard-2.7.3.28 (billiard/*, funtests/*, excluded _billiard.so) -celery-3.0.19 (celery/*, excluded bin/celery* and bin/camqadm) -django-celery-3.0.17 (djcelery/*, excluded bin/djcelerymon) -django-extensions-1.1.1 (django_extensions/*) +billiard-2.7.3.32 (billiard/*, funtests/*, excluded _billiard.so) +celery-3.0.22 (celery/*, excluded bin/celery* and bin/camqadm) +django-celery-3.0.21 (djcelery/*, excluded bin/djcelerymon) +django-extensions-1.2.0 (django_extensions/*) django-jsonfield-0.9.10 (jsonfield/*) -django-taggit-0.10a1 (taggit/*) -djangorestframework-2.3.5 (rest_framework/*) +django-taggit-0.10 (taggit/*) +djangorestframework-2.3.7 (rest_framework/*) importlib-1.0.2 (importlib/*, needed for Python 2.6 support) -kombu-2.5.10 (kombu/*) +kombu-2.5.14 (kombu/*) Markdown-2.3.1 (markdown/*, excluded bin/markdown_py) ordereddict-1.1 (ordereddict.py, needed for Python 2.6 support) pexpect-2.4 (pexpect.py, pxssh.py, fdpexpect.py, FSM.py, screen.py, ANSI.py) @@ -19,4 +19,4 @@ python-dateutil-2.1 (dateutil/*) pytz-2013b (pytz/*) requests-1.2.3 (requests/*) six-1.3.0 (six.py) -South-0.8.1 (south/*) +South-0.8.2 (south/*) diff --git a/awx/lib/site-packages/amqp/__init__.py b/awx/lib/site-packages/amqp/__init__.py index 82abe7f721..fca3bb317d 100644 --- a/awx/lib/site-packages/amqp/__init__.py +++ b/awx/lib/site-packages/amqp/__init__.py @@ -16,7 +16,7 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 from __future__ import absolute_import -VERSION = (1, 0, 11) +VERSION = (1, 0, 13) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __author__ = 'Barry Pederson' __maintainer__ = 'Ask Solem' diff --git a/awx/lib/site-packages/amqp/basic_message.py b/awx/lib/site-packages/amqp/basic_message.py index dc7c7a1237..192ede90b2 100644 --- a/awx/lib/site-packages/amqp/basic_message.py +++ b/awx/lib/site-packages/amqp/basic_message.py @@ -44,7 +44,7 @@ class Message(GenericContent): ('cluster_id', 'shortstr') ] - def __init__(self, body='', children=None, **properties): + def __init__(self, body='', children=None, channel=None, **properties): """Expected arg types body: string @@ -107,6 +107,7 @@ class Message(GenericContent): """ super(Message, self).__init__(**properties) self.body = body + self.channel = channel def __eq__(self, other): """Check if the properties and bodies of this Message and another diff --git a/awx/lib/site-packages/amqp/method_framing.py b/awx/lib/site-packages/amqp/method_framing.py index 0225ddd35f..a26db6672b 100644 --- a/awx/lib/site-packages/amqp/method_framing.py +++ b/awx/lib/site-packages/amqp/method_framing.py @@ -16,9 +16,8 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 from __future__ import absolute_import -from collections import defaultdict +from collections import defaultdict, deque from struct import pack, unpack -from Queue import Queue try: bytes @@ -61,12 +60,17 @@ class _PartialMessage(object): self.complete = (self.body_size == 0) def add_payload(self, payload): - self.body_parts.append(payload) + parts = self.body_parts self.body_received += len(payload) - if self.body_received == self.body_size: - self.msg.body = bytes().join(self.body_parts) + if parts: + parts.append(payload) + self.msg.body = bytes().join(parts) + else: + self.msg.body = payload self.complete = True + else: + parts.append(payload) class MethodReader(object): @@ -86,7 +90,7 @@ class MethodReader(object): def __init__(self, source): self.source = source - self.queue = Queue() + self.queue = deque() self.running = False self.partial_messages = {} self.heartbeats = 0 @@ -94,32 +98,33 @@ class MethodReader(object): self.expected_types = defaultdict(lambda: 1) # not an actual byte count, just incremented whenever we receive self.bytes_recv = 0 + self._quick_put = self.queue.append + self._quick_get = self.queue.popleft def _next_method(self): """Read the next method from the source, once one complete method has been assembled it is placed in the internal queue.""" - empty = self.queue.empty + queue = self.queue + put = self._quick_put read_frame = self.source.read_frame - while empty(): + while not queue: try: frame_type, channel, payload = read_frame() except Exception, e: # # Connection was closed? Framing Error? # - self.queue.put(e) + put(e) break self.bytes_recv += 1 if frame_type not in (self.expected_types[channel], 8): - self.queue.put(( + put(( channel, AMQPError( 'Received frame type %s while expecting type: %s' % ( - frame_type, self.expected_types[channel]) - ), - )) + frame_type, self.expected_types[channel])))) elif frame_type == 1: self._process_method_frame(channel, payload) elif frame_type == 2: @@ -144,7 +149,7 @@ class MethodReader(object): self.partial_messages[channel] = _PartialMessage(method_sig, args) self.expected_types[channel] = 2 else: - self.queue.put((channel, method_sig, args, None)) + self._quick_put((channel, method_sig, args, None)) def _process_content_header(self, channel, payload): """Process Content Header frames""" @@ -155,8 +160,8 @@ class MethodReader(object): # # a bodyless message, we're done # - self.queue.put((channel, partial.method_sig, - partial.args, partial.msg)) + self._quick_put((channel, partial.method_sig, + partial.args, partial.msg)) self.partial_messages.pop(channel, None) self.expected_types[channel] = 1 else: @@ -174,15 +179,15 @@ class MethodReader(object): # Stick the message in the queue and go back to # waiting for method frames # - self.queue.put((channel, partial.method_sig, - partial.args, partial.msg)) + self._quick_put((channel, partial.method_sig, + partial.args, partial.msg)) self.partial_messages.pop(channel, None) self.expected_types[channel] = 1 def read_method(self): """Read a method from the peer.""" self._next_method() - m = self.queue.get() + m = self._quick_get() if isinstance(m, Exception): raise m if isinstance(m, tuple) and isinstance(m[1], AMQPError): diff --git a/awx/lib/site-packages/amqp/transport.py b/awx/lib/site-packages/amqp/transport.py index 8092c5fec3..b441a11ced 100644 --- a/awx/lib/site-packages/amqp/transport.py +++ b/awx/lib/site-packages/amqp/transport.py @@ -52,6 +52,8 @@ from .exceptions import AMQPError AMQP_PORT = 5672 +EMPTY_BUFFER = bytes() + # Yes, Advanced Message Queuing Protocol Protocol is redundant AMQP_PROTOCOL_HEADER = 'AMQP\x01\x01\x00\x09'.encode('latin_1') @@ -139,11 +141,12 @@ class _AbstractTransport(object): self.sock.close() self.sock = None - def read_frame(self): + def read_frame(self, unpack=unpack): """Read an AMQP frame.""" - frame_type, channel, size = unpack('>BHI', self._read(7, True)) - payload = self._read(size) - ch = ord(self._read(1)) + read = self._read + frame_type, channel, size = unpack('>BHI', read(7, True)) + payload = read(size) + ch = ord(read(1)) if ch == 206: # '\xce' return frame_type, channel, payload else: @@ -164,7 +167,7 @@ class SSLTransport(_AbstractTransport): def __init__(self, host, connect_timeout, ssl): if isinstance(ssl, dict): self.sslopts = ssl - self.sslobj = None + self._read_buffer = EMPTY_BUFFER super(SSLTransport, self).__init__(host, connect_timeout) def _setup_transport(self): @@ -173,43 +176,51 @@ class SSLTransport(_AbstractTransport): lower version.""" if HAVE_PY26_SSL: if hasattr(self, 'sslopts'): - self.sslobj = ssl.wrap_socket(self.sock, **self.sslopts) + self.sock = ssl.wrap_socket(self.sock, **self.sslopts) else: - self.sslobj = ssl.wrap_socket(self.sock) - self.sslobj.do_handshake() + self.sock = ssl.wrap_socket(self.sock) + self.sock.do_handshake() else: - self.sslobj = socket.ssl(self.sock) + self.sock = socket.ssl(self.sock) + self._quick_recv = self.sock.read def _shutdown_transport(self): """Unwrap a Python 2.6 SSL socket, so we can call shutdown()""" - if HAVE_PY26_SSL and (self.sslobj is not None): - self.sock = self.sslobj.unwrap() - self.sslobj = None - - def _read(self, n, initial=False): - """It seems that SSL Objects read() method may not supply as much - as you're asking for, at least with extremely large messages. - somewhere > 16K - found this in the test_channel.py test_large - unittest.""" - result = '' - - while len(result) < n: + if HAVE_PY26_SSL and self.sock is not None: try: - s = self.sslobj.read(n - len(result)) + unwrap = self.sock.unwrap + except AttributeError: + return + self.sock = unwrap() + + def _read(self, n, initial=False, + _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)): + # According to SSL_read(3), it can at most return 16kb of data. + # Thus, we use an internal read buffer like TCPTransport._read + # to get the exact number of bytes wanted. + recv = self._quick_recv + rbuf = self._read_buffer + while len(rbuf) < n: + try: + s = recv(131072) # see note above except socket.error, exc: - if not initial and exc.errno in (errno.EAGAIN, errno.EINTR): + # ssl.sock.read may cause ENOENT if the + # operation couldn't be performed (Issue celery#1414). + if not initial and exc.errno in _errnos: continue - raise + raise exc if not s: raise IOError('Socket closed') - result += s + rbuf += s + result, self._read_buffer = rbuf[:n], rbuf[n:] return result def _write(self, s): """Write a string out to the SSL socket fully.""" + write = self.sock.write while s: - n = self.sslobj.write(s) + n = write(s) if not n: raise IOError('Socket closed') s = s[n:] @@ -222,24 +233,25 @@ class TCPTransport(_AbstractTransport): """Setup to _write() directly to the socket, and do our own buffered reads.""" self._write = self.sock.sendall - self._read_buffer = bytes() + self._read_buffer = EMPTY_BUFFER + self._quick_recv = self.sock.recv - def _read(self, n, initial=False): + def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): """Read exactly n bytes from the socket""" - while len(self._read_buffer) < n: + recv = self._quick_recv + rbuf = self._read_buffer + while len(rbuf) < n: try: - s = self.sock.recv(65536) + s = recv(131072) except socket.error, exc: - if not initial and exc.errno in (errno.EAGAIN, errno.EINTR): + if not initial and exc.errno in _errnos: continue raise if not s: raise IOError('Socket closed') - self._read_buffer += s - - result = self._read_buffer[:n] - self._read_buffer = self._read_buffer[n:] + rbuf += s + result, self._read_buffer = rbuf[:n], rbuf[n:] return result diff --git a/awx/lib/site-packages/billiard/__init__.py b/awx/lib/site-packages/billiard/__init__.py index 5b22a4fa1a..846d01c168 100644 --- a/awx/lib/site-packages/billiard/__init__.py +++ b/awx/lib/site-packages/billiard/__init__.py @@ -20,7 +20,7 @@ from __future__ import absolute_import from __future__ import with_statement -VERSION = (2, 7, 3, 28) +VERSION = (2, 7, 3, 32) __version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:]) __author__ = 'R Oudkerk / Python Software Foundation' __author_email__ = 'python-dev@python.org' @@ -232,7 +232,11 @@ def JoinableQueue(maxsize=0): return JoinableQueue(maxsize) -def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None): +def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None, + timeout=None, soft_timeout=None, lost_worker_timeout=None, + max_restarts=None, max_restart_freq=1, on_process_up=None, + on_process_down=None, on_timeout_set=None, on_timeout_cancel=None, + threads=True, semaphore=None, putlocks=False, allow_restart=False): ''' Returns a process pool object ''' diff --git a/awx/lib/site-packages/billiard/_ext.py b/awx/lib/site-packages/billiard/_ext.py index a02a156669..7d9caf01ae 100644 --- a/awx/lib/site-packages/billiard/_ext.py +++ b/awx/lib/site-packages/billiard/_ext.py @@ -4,6 +4,11 @@ import sys supports_exec = True +try: + import _winapi as win32 +except ImportError: # pragma: no cover + win32 = None + if sys.platform.startswith("java"): _billiard = None else: @@ -18,7 +23,8 @@ else: from multiprocessing.connection import Connection # noqa PipeConnection = getattr(_billiard, "PipeConnection", None) - win32 = getattr(_billiard, "win32", None) + if win32 is None: + win32 = getattr(_billiard, "win32", None) # noqa def ensure_multiprocessing(): diff --git a/awx/lib/site-packages/billiard/common.py b/awx/lib/site-packages/billiard/common.py index 04c37ea67d..5c367fd879 100644 --- a/awx/lib/site-packages/billiard/common.py +++ b/awx/lib/site-packages/billiard/common.py @@ -1,12 +1,41 @@ +# -*- coding: utf-8 -*- +""" +This module contains utilities added by billiard, to keep +"non-core" functionality out of ``.util``.""" from __future__ import absolute_import import signal import sys from time import time +import pickle as pypickle +try: + import cPickle as cpickle +except ImportError: # pragma: no cover + cpickle = None # noqa from .exceptions import RestartFreqExceeded +if sys.version_info < (2, 6): # pragma: no cover + # cPickle does not use absolute_imports + pickle = pypickle + pickle_load = pypickle.load + pickle_loads = pypickle.loads +else: + pickle = cpickle or pypickle + pickle_load = pickle.load + pickle_loads = pickle.loads + +# cPickle.loads does not support buffer() objects, +# but we can just create a StringIO and use load. +if sys.version_info[0] == 3: + from io import BytesIO +else: + try: + from cStringIO import StringIO as BytesIO # noqa + except ImportError: + from StringIO import StringIO as BytesIO # noqa + TERMSIGS = ( 'SIGHUP', 'SIGQUIT', @@ -30,6 +59,11 @@ TERMSIGS = ( ) +def pickle_loads(s, load=pickle_load): + # used to support buffer objects + return load(BytesIO(s)) + + def _shutdown_cleanup(signum, frame): sys.exit(-(256 - signum)) diff --git a/awx/lib/site-packages/billiard/pool.py b/awx/lib/site-packages/billiard/pool.py index 71c637656a..33e0fe0d29 100644 --- a/awx/lib/site-packages/billiard/pool.py +++ b/awx/lib/site-packages/billiard/pool.py @@ -17,7 +17,6 @@ from __future__ import with_statement import collections import errno import itertools -import logging import os import platform import signal @@ -29,7 +28,7 @@ import warnings from . import Event, Process, cpu_count from . import util -from .common import reset_signals, restart_state +from .common import pickle_loads, reset_signals, restart_state from .compat import get_errno from .einfo import ExceptionInfo from .exceptions import ( @@ -163,15 +162,6 @@ class LaxBoundedSemaphore(_Semaphore): if self._value < self._initial_value: self._value += 1 cond.notify_all() - if __debug__: - self._note( - "%s.release: success, value=%s", self, self._value, - ) - else: - if __debug__: - self._note( - "%s.release: success, value=%s (unchanged)" % ( - self, self._value)) def clear(self): while self._value < self._initial_value: @@ -184,14 +174,6 @@ class LaxBoundedSemaphore(_Semaphore): if self._Semaphore__value < self._initial_value: self._Semaphore__value += 1 cond.notifyAll() - if __debug__: - self._note("%s.release: success, value=%s", - self, self._Semaphore__value) - else: - if __debug__: - self._note( - "%s.release: success, value=%s (unchanged)" % ( - self, self._Semaphore__value)) def clear(self): # noqa while self._Semaphore__value < self._initial_value: @@ -233,28 +215,26 @@ def soft_timeout_sighandler(signum, frame): def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, sentinel=None): - # Re-init logging system. - # Workaround for http://bugs.python.org/issue6721#msg140215 - # Python logging module uses RLock() objects which are broken after - # fork. This can result in a deadlock (Issue #496). - logger_names = logging.Logger.manager.loggerDict.keys() - logger_names.append(None) # for root logger - for name in logger_names: - for handler in logging.getLogger(name).handlers: - handler.createLock() - logging._lock = threading.RLock() - pid = os.getpid() assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) put = outqueue.put get = inqueue.get + loads = pickle_loads if hasattr(inqueue, '_reader'): - def poll(timeout): - if inqueue._reader.poll(timeout): - return True, get() - return False, None + if hasattr(inqueue, 'get_payload') and inqueue.get_payload: + get_payload = inqueue.get_payload + + def poll(timeout): + if inqueue._reader.poll(timeout): + return True, loads(get_payload()) + return False, None + else: + def poll(timeout): + if inqueue._reader.poll(timeout): + return True, get() + return False, None else: def poll(timeout): # noqa @@ -1236,8 +1216,13 @@ class Pool(object): return result def terminate_job(self, pid, sig=None): - self.signalled.add(pid) - _kill(pid, sig or signal.SIGTERM) + try: + _kill(pid, sig or signal.SIGTERM) + except OSError, exc: + if get_errno(exc) != errno.ESRCH: + raise + else: + self.signalled.add(pid) def map_async(self, func, iterable, chunksize=None, callback=None, error_callback=None): diff --git a/awx/lib/site-packages/billiard/process.py b/awx/lib/site-packages/billiard/process.py index 79c850bd4b..b8418229db 100644 --- a/awx/lib/site-packages/billiard/process.py +++ b/awx/lib/site-packages/billiard/process.py @@ -19,6 +19,8 @@ import sys import signal import itertools import binascii +import logging +import threading from .compat import bytes try: @@ -45,9 +47,17 @@ def current_process(): def _cleanup(): # check for processes which have finished - for p in list(_current_process._children): - if p._popen.poll() is not None: - _current_process._children.discard(p) + if _current_process is not None: + for p in list(_current_process._children): + if p._popen.poll() is not None: + _current_process._children.discard(p) + + +def _maybe_flush(f): + try: + f.flush() + except (AttributeError, EnvironmentError, NotImplementedError): + pass def active_children(_cleanup=_cleanup): @@ -59,7 +69,9 @@ def active_children(_cleanup=_cleanup): except TypeError: # called after gc collect so _cleanup does not exist anymore return [] - return list(_current_process._children) + if _current_process is not None: + return list(_current_process._children) + return [] class Process(object): @@ -242,6 +254,18 @@ class Process(object): pass old_process = _current_process _current_process = self + + # Re-init logging system. + # Workaround for http://bugs.python.org/issue6721#msg140215 + # Python logging module uses RLock() objects which are broken after + # fork. This can result in a deadlock (Celery Issue #496). + logger_names = logging.Logger.manager.loggerDict.keys() + logger_names.append(None) # for root logger + for name in logger_names: + for handler in logging.getLogger(name).handlers: + handler.createLock() + logging._lock = threading.RLock() + try: util._finalizer_registry.clear() util._run_after_forkers() @@ -262,7 +286,7 @@ class Process(object): exitcode = e.args[0] else: sys.stderr.write(str(e.args[0]) + '\n') - sys.stderr.flush() + _maybe_flush(sys.stderr) exitcode = 0 if isinstance(e.args[0], str) else 1 except: exitcode = 1 @@ -273,8 +297,8 @@ class Process(object): finally: util.info('process %s exiting with exitcode %d', self.pid, exitcode) - sys.stdout.flush() - sys.stderr.flush() + _maybe_flush(sys.stdout) + _maybe_flush(sys.stderr) return exitcode # diff --git a/awx/lib/site-packages/billiard/queues.py b/awx/lib/site-packages/billiard/queues.py index a44c2b47a7..2554a5ef07 100644 --- a/awx/lib/site-packages/billiard/queues.py +++ b/awx/lib/site-packages/billiard/queues.py @@ -334,6 +334,10 @@ class SimpleQueue(object): def _make_methods(self): recv = self._reader.recv + try: + recv_payload = self._reader.recv_payload + except AttributeError: + recv_payload = None # C extension not installed rlock = self._rlock def get(): @@ -341,6 +345,12 @@ class SimpleQueue(object): return recv() self.get = get + if recv_payload is not None: + def get_payload(): + with rlock: + return recv_payload() + self.get_payload = get_payload + if self._wlock is None: # writes to a message oriented win32 pipe are atomic self.put = self._writer.send diff --git a/awx/lib/site-packages/billiard/util.py b/awx/lib/site-packages/billiard/util.py index f669512049..76c8431a6f 100644 --- a/awx/lib/site-packages/billiard/util.py +++ b/awx/lib/site-packages/billiard/util.py @@ -244,7 +244,9 @@ class Finalize(object): return x + '>' -def _run_finalizers(minpriority=None): +def _run_finalizers(minpriority=None, + _finalizer_registry=_finalizer_registry, + sub_debug=sub_debug, error=error): ''' Run all finalizers whose exit priority is not None and at least minpriority @@ -280,7 +282,9 @@ def is_exiting(): return _exiting or _exiting is None -def _exit_function(): +def _exit_function(info=info, debug=debug, + active_children=active_children, + _run_finalizers=_run_finalizers): ''' Clean up on exit ''' diff --git a/awx/lib/site-packages/celery/__compat__.py b/awx/lib/site-packages/celery/__compat__.py index e09772346c..08700719cb 100644 --- a/awx/lib/site-packages/celery/__compat__.py +++ b/awx/lib/site-packages/celery/__compat__.py @@ -14,7 +14,12 @@ from __future__ import absolute_import import operator import sys -from functools import reduce +# import fails in python 2.5. fallback to reduce in stdlib +try: + from functools import reduce +except ImportError: + pass + from importlib import import_module from types import ModuleType diff --git a/awx/lib/site-packages/celery/__init__.py b/awx/lib/site-packages/celery/__init__.py index 0aa88032cf..49a5ed9baf 100644 --- a/awx/lib/site-packages/celery/__init__.py +++ b/awx/lib/site-packages/celery/__init__.py @@ -8,7 +8,7 @@ from __future__ import absolute_import SERIES = 'Chiastic Slide' -VERSION = (3, 0, 19) +VERSION = (3, 0, 22) __version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:]) __author__ = 'Ask Solem' __contact__ = 'ask@celeryproject.org' diff --git a/awx/lib/site-packages/celery/app/builtins.py b/awx/lib/site-packages/celery/app/builtins.py index 0c3582f5ed..1938af21cc 100644 --- a/awx/lib/site-packages/celery/app/builtins.py +++ b/awx/lib/site-packages/celery/app/builtins.py @@ -307,8 +307,9 @@ def add_chord_task(app): accept_magic_kwargs = False ignore_result = False - def run(self, header, body, partial_args=(), interval=1, countdown=1, - max_retries=None, propagate=None, eager=False, **kwargs): + def run(self, header, body, partial_args=(), interval=None, + countdown=1, max_retries=None, propagate=None, + eager=False, **kwargs): propagate = default_propagate if propagate is None else propagate group_id = uuid() AsyncResult = self.app.AsyncResult diff --git a/awx/lib/site-packages/celery/app/task.py b/awx/lib/site-packages/celery/app/task.py index d3c0fe3522..94282ecbf2 100644 --- a/awx/lib/site-packages/celery/app/task.py +++ b/awx/lib/site-packages/celery/app/task.py @@ -198,11 +198,11 @@ class Task(object): serializer = None #: Hard time limit. - #: Defaults to the :setting:`CELERY_TASK_TIME_LIMIT` setting. + #: Defaults to the :setting:`CELERYD_TASK_TIME_LIMIT` setting. time_limit = None #: Soft time limit. - #: Defaults to the :setting:`CELERY_TASK_SOFT_TIME_LIMIT` setting. + #: Defaults to the :setting:`CELERYD_TASK_SOFT_TIME_LIMIT` setting. soft_time_limit = None #: The result store backend used for this task. @@ -459,7 +459,8 @@ class Task(object): args = (self.__self__, ) + tuple(args) if conf.CELERY_ALWAYS_EAGER: - return self.apply(args, kwargs, task_id=task_id, **options) + return self.apply(args, kwargs, task_id=task_id, + link=link, link_error=link_error, **options) options = dict(extract_exec_options(self), **options) options = router.route(options, self.name, args, kwargs) @@ -580,7 +581,8 @@ class Task(object): raise ret return ret - def apply(self, args=None, kwargs=None, **options): + def apply(self, args=None, kwargs=None, + link=None, link_error=None, **options): """Execute this task locally, by blocking until the task returns. :param args: positional arguments passed on to the task. @@ -614,6 +616,8 @@ class Task(object): 'is_eager': True, 'logfile': options.get('logfile'), 'loglevel': options.get('loglevel', 0), + 'callbacks': maybe_list(link), + 'errbacks': maybe_list(link_error), 'delivery_info': {'is_eager': True}} if self.accept_magic_kwargs: default_kwargs = {'task_name': task.name, diff --git a/awx/lib/site-packages/celery/apps/worker.py b/awx/lib/site-packages/celery/apps/worker.py index c3e6d6da42..2d6d67c19a 100644 --- a/awx/lib/site-packages/celery/apps/worker.py +++ b/awx/lib/site-packages/celery/apps/worker.py @@ -251,7 +251,7 @@ class Worker(configurated): 'version': VERSION_BANNER, 'conninfo': self.app.connection().as_uri(), 'concurrency': concurrency, - 'platform': _platform.platform(), + 'platform': safe_str(_platform.platform()), 'events': events, 'queues': app.amqp.queues.format(indent=0, indent_first=False), }).splitlines() diff --git a/awx/lib/site-packages/celery/backends/mongodb.py b/awx/lib/site-packages/celery/backends/mongodb.py index 2027d66b3b..30bbf673f3 100644 --- a/awx/lib/site-packages/celery/backends/mongodb.py +++ b/awx/lib/site-packages/celery/backends/mongodb.py @@ -114,6 +114,8 @@ class MongoBackend(BaseDictBackend): if self._connection is not None: # MongoDB connection will be closed automatically when object # goes out of scope + del(self.collection) + del(self.database) self._connection = None def _store_result(self, task_id, result, status, traceback=None): @@ -124,7 +126,7 @@ class MongoBackend(BaseDictBackend): 'date_done': datetime.utcnow(), 'traceback': Binary(self.encode(traceback)), 'children': Binary(self.encode(self.current_task_children()))} - self.collection.save(meta, safe=True) + self.collection.save(meta) return result @@ -151,7 +153,7 @@ class MongoBackend(BaseDictBackend): meta = {'_id': group_id, 'result': Binary(self.encode(result)), 'date_done': datetime.utcnow()} - self.collection.save(meta, safe=True) + self.collection.save(meta) return result @@ -183,7 +185,7 @@ class MongoBackend(BaseDictBackend): # By using safe=True, this will wait until it receives a response from # the server. Likewise, it will raise an OperationsError if the # response was unable to be completed. - self.collection.remove({'_id': task_id}, safe=True) + self.collection.remove({'_id': task_id}) def cleanup(self): """Delete expired metadata.""" diff --git a/awx/lib/site-packages/celery/bin/celery.py b/awx/lib/site-packages/celery/bin/celery.py index d92a497ed6..c0dd4a8e1e 100644 --- a/awx/lib/site-packages/celery/bin/celery.py +++ b/awx/lib/site-packages/celery/bin/celery.py @@ -614,7 +614,7 @@ class control(_RemoteControl): def call(self, method, *args, **options): # XXX Python 2.5 doesn't support X(*args, reply=True, **kwargs) return getattr(self.app.control, method)( - *args, **dict(options, retry=True)) + *args, **dict(options, reply=True)) def pool_grow(self, method, n=1, **kwargs): """[N=1]""" @@ -866,7 +866,7 @@ class CeleryCommand(BaseCommand): cls = self.commands.get(command) or self.commands['help'] try: return cls(app=self.app).run_from_argv(self.prog_name, argv) - except (TypeError, Error): + except Error: return self.execute('help', argv) def remove_options_at_beginning(self, argv, index=0): diff --git a/awx/lib/site-packages/celery/canvas.py b/awx/lib/site-packages/celery/canvas.py index d8dcfabcad..0eb360f6de 100644 --- a/awx/lib/site-packages/celery/canvas.py +++ b/awx/lib/site-packages/celery/canvas.py @@ -267,7 +267,8 @@ class Signature(dict): class chain(Signature): def __init__(self, *tasks, **options): - tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks + tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0]) + else tasks) Signature.__init__( self, 'celery.chain', (), {'tasks': tasks}, **options ) @@ -283,7 +284,7 @@ class chain(Signature): tasks = d['kwargs']['tasks'] if d['args'] and tasks: # partial args passed on to first task in chain (Issue #1057). - tasks[0]['args'] = d['args'] + tasks[0]['args'] + tasks[0]['args'] = tasks[0]._merge(d['args'])[0] return chain(*d['kwargs']['tasks'], **kwdict(d['options'])) @property @@ -392,7 +393,7 @@ class group(Signature): if d['args'] and tasks: # partial args passed on to all tasks in the group (Issue #1057). for task in tasks: - task['args'] = d['args'] + task['args'] + task['args'] = task._merge(d['args'])[0] return group(tasks, **kwdict(d['options'])) def __call__(self, *partial_args, **options): diff --git a/awx/lib/site-packages/celery/concurrency/eventlet.py b/awx/lib/site-packages/celery/concurrency/eventlet.py index fd97269f61..5f8d68750f 100644 --- a/awx/lib/site-packages/celery/concurrency/eventlet.py +++ b/awx/lib/site-packages/celery/concurrency/eventlet.py @@ -34,7 +34,8 @@ if not EVENTLET_NOPATCH and not PATCHED[0]: import eventlet import eventlet.debug eventlet.monkey_patch() - eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK) + if EVENTLET_DBLOCK: + eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK) from time import time diff --git a/awx/lib/site-packages/celery/contrib/migrate.py b/awx/lib/site-packages/celery/contrib/migrate.py index 9e54979455..76fe1db2fa 100644 --- a/awx/lib/site-packages/celery/contrib/migrate.py +++ b/awx/lib/site-packages/celery/contrib/migrate.py @@ -236,6 +236,8 @@ def start_filter(app, conn, filter, limit=None, timeout=1.0, consume_from=None, state=None, **kwargs): state = state or State() queues = prepare_queues(queues) + consume_from = [_maybe_queue(app, q) + for q in consume_from or queues.keys()] if isinstance(tasks, basestring): tasks = set(tasks.split(',')) if tasks is None: diff --git a/awx/lib/site-packages/celery/datastructures.py b/awx/lib/site-packages/celery/datastructures.py index f3e9c2e2b6..67f6c66b39 100644 --- a/awx/lib/site-packages/celery/datastructures.py +++ b/awx/lib/site-packages/celery/datastructures.py @@ -472,7 +472,10 @@ class LimitedSet(object): if time.time() < item[0] + self.expires: heappush(H, item) break - self._data.pop(item[1]) + try: + self._data.pop(item[1]) + except KeyError: # out of sync with heap + pass i += 1 def update(self, other, heappush=heappush): diff --git a/awx/lib/site-packages/celery/events/state.py b/awx/lib/site-packages/celery/events/state.py index 0afbec5383..8c129c2a38 100644 --- a/awx/lib/site-packages/celery/events/state.py +++ b/awx/lib/site-packages/celery/events/state.py @@ -210,11 +210,20 @@ class State(object): task_count = 0 def __init__(self, callback=None, + workers=None, tasks=None, taskheap=None, max_workers_in_memory=5000, max_tasks_in_memory=10000): - self.workers = LRUCache(limit=max_workers_in_memory) - self.tasks = LRUCache(limit=max_tasks_in_memory) self.event_callback = callback + self.workers = (LRUCache(max_workers_in_memory) + if workers is None else workers) + self.tasks = (LRUCache(max_tasks_in_memory) + if tasks is None else tasks) + self._taskheap = None # reserved for __reduce__ in 3.1 + self.max_workers_in_memory = max_workers_in_memory + self.max_tasks_in_memory = max_tasks_in_memory self._mutex = threading.Lock() + self.handlers = {'task': self.task_event, + 'worker': self.worker_event} + self._get_handler = self.handlers.__getitem__ def freeze_while(self, fun, *args, **kwargs): clear_after = kwargs.pop('clear_after', False) @@ -295,11 +304,14 @@ class State(object): with self._mutex: return self._dispatch_event(event) - def _dispatch_event(self, event): + def _dispatch_event(self, event, kwdict=kwdict): self.event_count += 1 event = kwdict(event) group, _, subject = event['type'].partition('-') - getattr(self, group + '_event')(subject, event) + try: + self._get_handler(group)(subject, event) + except KeyError: + pass if self.event_callback: self.event_callback(self, event) @@ -356,14 +368,10 @@ class State(object): return '<ClusterState: events=%s tasks=%s>' % (self.event_count, self.task_count) - def __getstate__(self): - d = dict(vars(self)) - d.pop('_mutex') - return d - - def __setstate__(self, state): - self.__dict__ = state - self._mutex = threading.Lock() - + def __reduce__(self): + return self.__class__, ( + self.event_callback, self.workers, self.tasks, None, + self.max_workers_in_memory, self.max_tasks_in_memory, + ) state = State() diff --git a/awx/lib/site-packages/celery/loaders/__init__.py b/awx/lib/site-packages/celery/loaders/__init__.py index 6f8aea72e9..1bd2baafcb 100644 --- a/awx/lib/site-packages/celery/loaders/__init__.py +++ b/awx/lib/site-packages/celery/loaders/__init__.py @@ -11,7 +11,7 @@ from __future__ import absolute_import from celery._state import current_app from celery.utils import deprecated -from celery.utils.imports import symbol_by_name +from celery.utils.imports import symbol_by_name, import_from_cwd LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader', 'default': 'celery.loaders.default:Loader', @@ -20,7 +20,7 @@ LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader', def get_loader_cls(loader): """Get loader class by name/alias""" - return symbol_by_name(loader, LOADER_ALIASES) + return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd) @deprecated(deprecation='2.5', removal='4.0', diff --git a/awx/lib/site-packages/celery/platforms.py b/awx/lib/site-packages/celery/platforms.py index 7919cb10b8..3e6b0f5338 100644 --- a/awx/lib/site-packages/celery/platforms.py +++ b/awx/lib/site-packages/celery/platforms.py @@ -48,6 +48,12 @@ PIDFILE_MODE = ((os.R_OK | os.W_OK) << 6) | ((os.R_OK) << 3) | ((os.R_OK)) PIDLOCKED = """ERROR: Pidfile (%s) already exists. Seems we're already running? (pid: %s)""" +try: + from io import UnsupportedOperation + FILENO_ERRORS = (AttributeError, UnsupportedOperation) +except ImportError: # Py2 + FILENO_ERRORS = (AttributeError, ) # noqa + def pyimplementation(): """Returns string identifying the current Python implementation.""" @@ -253,17 +259,21 @@ def _create_pidlock(pidfile): def fileno(f): - """Get object fileno, or :const:`None` if not defined.""" - if isinstance(f, int): + if isinstance(f, (int, long)): return f + return f.fileno() + + +def maybe_fileno(f): + """Get object fileno, or :const:`None` if not defined.""" try: - return f.fileno() - except AttributeError: + return fileno(f) + except FILENO_ERRORS: pass def close_open_fds(keep=None): - keep = [fileno(f) for f in keep if fileno(f)] if keep else [] + keep = [maybe_fileno(f) for f in keep if maybe_fileno(f)] if keep else [] for fd in reversed(range(get_fdmax(default=2048))): if fd not in keep: with ignore_errno(errno.EBADF): @@ -299,7 +309,7 @@ class DaemonContext(object): close_open_fds(self.stdfds) for fd in self.stdfds: - self.redirect_to_null(fileno(fd)) + self.redirect_to_null(maybe_fileno(fd)) self._is_open = True __enter__ = open diff --git a/awx/lib/site-packages/celery/result.py b/awx/lib/site-packages/celery/result.py index 1b6af3aec9..acf1ff6cbe 100644 --- a/awx/lib/site-packages/celery/result.py +++ b/awx/lib/site-packages/celery/result.py @@ -350,7 +350,7 @@ class ResultSet(ResultBase): def failed(self): """Did any of the tasks fail? - :returns: :const:`True` if any of the tasks failed. + :returns: :const:`True` if one of the tasks failed. (i.e., raised an exception) """ @@ -359,7 +359,7 @@ class ResultSet(ResultBase): def waiting(self): """Are any of the tasks incomplete? - :returns: :const:`True` if any of the tasks is still + :returns: :const:`True` if one of the tasks are still waiting for execution. """ @@ -368,7 +368,7 @@ class ResultSet(ResultBase): def ready(self): """Did all of the tasks complete? (either by success of failure). - :returns: :const:`True` if all of the tasks been + :returns: :const:`True` if all of the tasks has been executed. """ @@ -435,7 +435,7 @@ class ResultSet(ResultBase): time.sleep(interval) elapsed += interval if timeout and elapsed >= timeout: - raise TimeoutError("The operation timed out") + raise TimeoutError('The operation timed out') def get(self, timeout=None, propagate=True, interval=0.5): """See :meth:`join` @@ -694,7 +694,7 @@ class EagerResult(AsyncResult): self._state = states.REVOKED def __repr__(self): - return "<EagerResult: %s>" % self.id + return '<EagerResult: %s>' % self.id @property def result(self): diff --git a/awx/lib/site-packages/celery/schedules.py b/awx/lib/site-packages/celery/schedules.py index ca0e3aa314..cc7353d02c 100644 --- a/awx/lib/site-packages/celery/schedules.py +++ b/awx/lib/site-packages/celery/schedules.py @@ -379,7 +379,11 @@ class crontab(schedule): flag = (datedata.dom == len(days_of_month) or day_out_of_range(datedata.year, months_of_year[datedata.moy], - days_of_month[datedata.dom])) + days_of_month[datedata.dom]) or + (self.maybe_make_aware(datetime(datedata.year, + months_of_year[datedata.moy], + days_of_month[datedata.dom])) < last_run_at)) + if flag: datedata.dom = 0 datedata.moy += 1 @@ -449,10 +453,11 @@ class crontab(schedule): self._orig_day_of_month, self._orig_month_of_year), None) - def remaining_estimate(self, last_run_at, tz=None): + def remaining_delta(self, last_run_at, tz=None): """Returns when the periodic task should run next as a timedelta.""" tz = tz or self.tz last_run_at = self.maybe_make_aware(last_run_at) + now = self.maybe_make_aware(self.now()) dow_num = last_run_at.isoweekday() % 7 # Sunday is day 0, not day 7 execute_this_date = (last_run_at.month in self.month_of_year and @@ -460,6 +465,9 @@ class crontab(schedule): dow_num in self.day_of_week) execute_this_hour = (execute_this_date and + last_run_at.day == now.day and + last_run_at.month == now.month and + last_run_at.year == now.year and last_run_at.hour in self.hour and last_run_at.minute < max(self.minute)) @@ -499,10 +507,11 @@ class crontab(schedule): else: delta = self._delta_to_next(last_run_at, next_hour, next_minute) + return self.to_local(last_run_at), delta, self.to_local(now) - now = self.maybe_make_aware(self.now()) - return remaining(self.to_local(last_run_at), delta, - self.to_local(now)) + def remaining_estimate(self, last_run_at): + """Returns when the periodic task should run next as a timedelta.""" + return remaining(*self.remaining_delta(last_run_at)) def is_due(self, last_run_at): """Returns tuple of two items `(is_due, next_time_to_run)`, diff --git a/awx/lib/site-packages/celery/task/trace.py b/awx/lib/site-packages/celery/task/trace.py index ab1a4d3e08..6a2a3bef88 100644 --- a/awx/lib/site-packages/celery/task/trace.py +++ b/awx/lib/site-packages/celery/task/trace.py @@ -30,7 +30,10 @@ from celery.app import set_default_app from celery.app.task import Task as BaseTask, Context from celery.datastructures import ExceptionInfo from celery.exceptions import Ignore, RetryTaskError -from celery.utils.serialization import get_pickleable_exception +from celery.utils.serialization import ( + get_pickleable_exception, + get_pickleable_etype, +) from celery.utils.log import get_logger _logger = get_logger(__name__) @@ -128,7 +131,9 @@ class TraceInfo(object): type_, _, tb = sys.exc_info() try: exc = self.retval - einfo = ExceptionInfo((type_, get_pickleable_exception(exc), tb)) + einfo = ExceptionInfo() + einfo.exception = get_pickleable_exception(einfo.exception) + einfo.type = get_pickleable_etype(einfo.type) if store_errors: task.backend.mark_as_failure(req.id, exc, einfo.traceback) task.on_failure(exc, req.id, req.args, req.kwargs, einfo) diff --git a/awx/lib/site-packages/celery/tests/backends/test_base.py b/awx/lib/site-packages/celery/tests/backends/test_base.py index cae67425c8..b04919c3b4 100644 --- a/awx/lib/site-packages/celery/tests/backends/test_base.py +++ b/awx/lib/site-packages/celery/tests/backends/test_base.py @@ -11,8 +11,7 @@ from celery import current_app from celery.result import AsyncResult, GroupResult from celery.utils import serialization from celery.utils.serialization import subclass_exception -from celery.utils.serialization import \ - find_nearest_pickleable_exception as fnpe +from celery.utils.serialization import find_pickleable_exception as fnpe from celery.utils.serialization import UnpickleableExceptionWrapper from celery.utils.serialization import get_pickleable_exception as gpe diff --git a/awx/lib/site-packages/celery/tests/backends/test_mongodb.py b/awx/lib/site-packages/celery/tests/backends/test_mongodb.py index 8980176581..3c15ab40db 100644 --- a/awx/lib/site-packages/celery/tests/backends/test_mongodb.py +++ b/awx/lib/site-packages/celery/tests/backends/test_mongodb.py @@ -287,7 +287,7 @@ class test_MongoBackend(AppCase): mock_database.__getitem__.assert_called_once_with( MONGODB_COLLECTION) mock_collection.remove.assert_called_once_with( - {'_id': sentinel.task_id}, safe=True) + {'_id': sentinel.task_id}) @patch('celery.backends.mongodb.MongoBackend._get_database') def test_cleanup(self, mock_get_database): diff --git a/awx/lib/site-packages/celery/tests/tasks/test_canvas.py b/awx/lib/site-packages/celery/tests/tasks/test_canvas.py index 8011a80de8..a56033ddd2 100644 --- a/awx/lib/site-packages/celery/tests/tasks/test_canvas.py +++ b/awx/lib/site-packages/celery/tests/tasks/test_canvas.py @@ -134,6 +134,11 @@ class test_chain(Case): self.assertEqual(res.parent.parent.get(), 8) self.assertIsNone(res.parent.parent.parent) + def test_accepts_generator_argument(self): + x = chain(add.s(i) for i in range(10)) + self.assertTrue(x.tasks[0].type, add) + self.assertTrue(x.type) + class test_group(Case): diff --git a/awx/lib/site-packages/celery/tests/tasks/test_tasks.py b/awx/lib/site-packages/celery/tests/tasks/test_tasks.py index dd53e50faf..65e6e77862 100644 --- a/awx/lib/site-packages/celery/tests/tasks/test_tasks.py +++ b/awx/lib/site-packages/celery/tests/tasks/test_tasks.py @@ -1,6 +1,8 @@ from __future__ import absolute_import from __future__ import with_statement +import time + from datetime import datetime, timedelta from functools import wraps from mock import patch @@ -616,6 +618,14 @@ def monthly(): pass +@periodic_task(run_every=crontab(hour=22, + day_of_week='*', + month_of_year='2', + day_of_month='26,27,28')) +def monthly_moy(): + pass + + @periodic_task(run_every=crontab(hour=7, minute=30, day_of_week='thursday', day_of_month='8-14', @@ -1212,6 +1222,40 @@ class test_crontab_is_due(Case): self.assertFalse(due) self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60) + @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0)) + def test_monthly_moy_execution_is_due(self): + due, remaining = monthly_moy.run_every.is_due( + datetime(2013, 7, 4, 10, 0)) + self.assertTrue(due) + self.assertEqual(remaining, 60.) + + @patch_crontab_nowfun(monthly_moy, datetime(2013, 6, 28, 14, 30)) + def test_monthly_moy_execution_is_not_due(self): + due, remaining = monthly_moy.run_every.is_due( + datetime(2013, 6, 28, 22, 14)) + self.assertFalse(due) + attempt = ( + time.mktime(datetime(2014, 2, 26, 22, 0).timetuple()) - + time.mktime(datetime(2013, 6, 28, 14, 30).timetuple()) - + 60 * 60 + ) + self.assertEqual(remaining, attempt) + + @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0)) + def test_monthly_moy_execution_is_due2(self): + due, remaining = monthly_moy.run_every.is_due( + datetime(2013, 2, 28, 10, 0)) + self.assertTrue(due) + self.assertEqual(remaining, 60.) + + @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 21, 0)) + def test_monthly_moy_execution_is_not_due2(self): + due, remaining = monthly_moy.run_every.is_due( + datetime(2013, 6, 28, 22, 14)) + self.assertFalse(due) + attempt = 60 * 60 + self.assertEqual(remaining, attempt) + @patch_crontab_nowfun(yearly, datetime(2010, 3, 11, 7, 30)) def test_yearly_execution_is_due(self): due, remaining = yearly.run_every.is_due( diff --git a/awx/lib/site-packages/celery/tests/utilities/test_encoding.py b/awx/lib/site-packages/celery/tests/utilities/test_encoding.py index a7bd28a411..d9885b857f 100644 --- a/awx/lib/site-packages/celery/tests/utilities/test_encoding.py +++ b/awx/lib/site-packages/celery/tests/utilities/test_encoding.py @@ -1,9 +1,5 @@ from __future__ import absolute_import -import sys - -from nose import SkipTest - from celery.utils import encoding from celery.tests.utils import Case @@ -15,17 +11,6 @@ class test_encoding(Case): self.assertTrue(encoding.safe_str('foo')) self.assertTrue(encoding.safe_str(u'foo')) - def test_safe_str_UnicodeDecodeError(self): - if sys.version_info >= (3, 0): - raise SkipTest('py3k: not relevant') - - class foo(unicode): - - def encode(self, *args, **kwargs): - raise UnicodeDecodeError('foo') - - self.assertIn('<Unrepresentable', encoding.safe_str(foo())) - def test_safe_repr(self): self.assertTrue(encoding.safe_repr(object())) diff --git a/awx/lib/site-packages/celery/tests/utilities/test_term.py b/awx/lib/site-packages/celery/tests/utilities/test_term.py index ce1285701d..5e6ca50631 100644 --- a/awx/lib/site-packages/celery/tests/utilities/test_term.py +++ b/awx/lib/site-packages/celery/tests/utilities/test_term.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import +from kombu.utils.encoding import str_t + from celery.utils import term from celery.utils.term import colored, fg @@ -38,7 +40,7 @@ class test_colored(Case): self.assertTrue(str(colored().iwhite('f'))) self.assertTrue(str(colored().reset('f'))) - self.assertTrue(str(colored().green(u'∂bar'))) + self.assertTrue(str_t(colored().green(u'∂bar'))) self.assertTrue( colored().red(u'éefoo') + colored().green(u'∂bar')) diff --git a/awx/lib/site-packages/celery/tests/worker/test_worker.py b/awx/lib/site-packages/celery/tests/worker/test_worker.py index 006488dbf5..044f033073 100644 --- a/awx/lib/site-packages/celery/tests/worker/test_worker.py +++ b/awx/lib/site-packages/celery/tests/worker/test_worker.py @@ -967,10 +967,12 @@ class test_WorkController(AppCase): except ImportError: raise SkipTest('multiprocessing not supported') self.assertIsInstance(worker.ready_queue, AsyncTaskBucket) - self.assertFalse(worker.mediator) - self.assertNotEqual(worker.ready_queue.put, worker.process_task) + # XXX disabled until 3.1 + #self.assertFalse(worker.mediator) + #self.assertNotEqual(worker.ready_queue.put, worker.process_task) def test_disable_rate_limits_processes(self): + raise SkipTest('disabled until v3.1') try: worker = self.create_worker(disable_rate_limits=True, use_eventloop=False, @@ -1058,6 +1060,7 @@ class test_WorkController(AppCase): self.assertTrue(w.disable_rate_limits) def test_Queues_pool_no_sem(self): + raise SkipTest('disabled until v3.1') w = Mock() w.pool_cls.uses_semaphore = False Queues(w).create(w) @@ -1086,6 +1089,7 @@ class test_WorkController(AppCase): w.hub.on_init = [] w.pool_cls = Mock() P = w.pool_cls.return_value = Mock() + P._cache = {} P.timers = {Mock(): 30} w.use_eventloop = True w.consumer.restart_count = -1 @@ -1105,23 +1109,13 @@ class test_WorkController(AppCase): cbs['on_process_down'](w) hub.remove.assert_called_with(w.sentinel) + w.pool._tref_for_id = {} + result = Mock() - tref = result._tref cbs['on_timeout_cancel'](result) - tref.cancel.assert_called_with() cbs['on_timeout_cancel'](result) # no more tref - cbs['on_timeout_set'](result, 10, 20) - tsoft, callback = hub.timer.apply_after.call_args[0] - callback() - - cbs['on_timeout_set'](result, 10, None) - tsoft, callback = hub.timer.apply_after.call_args[0] - callback() - cbs['on_timeout_set'](result, None, 10) - cbs['on_timeout_set'](result, None, None) - with self.assertRaises(WorkerLostError): P.did_start_ok.return_value = False w.consumer.restart_count = 0 diff --git a/awx/lib/site-packages/celery/utils/imports.py b/awx/lib/site-packages/celery/utils/imports.py index 0ea1d77523..e46462663e 100644 --- a/awx/lib/site-packages/celery/utils/imports.py +++ b/awx/lib/site-packages/celery/utils/imports.py @@ -28,15 +28,18 @@ class NotAPackage(Exception): if sys.version_info >= (3, 3): # pragma: no cover def qualname(obj): - return obj.__qualname__ - + if not hasattr(obj, '__name__') and hasattr(obj, '__class__'): + obj = obj.__class__ + q = getattr(obj, '__qualname__', None) + if '.' not in q: + q = '.'.join((obj.__module__, q)) + return q else: def qualname(obj): # noqa if not hasattr(obj, '__name__') and hasattr(obj, '__class__'): return qualname(obj.__class__) - - return '%s.%s' % (obj.__module__, obj.__name__) + return '.'.join((obj.__module__, obj.__name__)) def instantiate(name, *args, **kwargs): diff --git a/awx/lib/site-packages/celery/utils/serialization.py b/awx/lib/site-packages/celery/utils/serialization.py index 0a7cdf25c7..0cce7e971a 100644 --- a/awx/lib/site-packages/celery/utils/serialization.py +++ b/awx/lib/site-packages/celery/utils/serialization.py @@ -50,7 +50,8 @@ else: return type(name, (parent,), {'__module__': module}) -def find_nearest_pickleable_exception(exc): +def find_pickleable_exception(exc, loads=pickle.loads, + dumps=pickle.dumps): """With an exception instance, iterate over its super classes (by mro) and find the first super exception that is pickleable. It does not go below :exc:`Exception` (i.e. it skips :exc:`Exception`, @@ -65,7 +66,19 @@ def find_nearest_pickleable_exception(exc): :rtype :exc:`Exception`: """ - cls = exc.__class__ + exc_args = getattr(exc, 'args', []) + for supercls in itermro(exc.__class__, unwanted_base_classes): + try: + superexc = supercls(*exc_args) + loads(dumps(superexc)) + except: + pass + else: + return superexc +find_nearest_pickleable_exception = find_pickleable_exception # XXX compat + + +def itermro(cls, stop): getmro_ = getattr(cls, 'mro', None) # old-style classes doesn't have mro() @@ -77,18 +90,11 @@ def find_nearest_pickleable_exception(exc): getmro_ = lambda: inspect.getmro(cls) for supercls in getmro_(): - if supercls in unwanted_base_classes: + if supercls in stop: # only BaseException and object, from here on down, # we don't care about these. return - try: - exc_args = getattr(exc, 'args', []) - superexc = supercls(*exc_args) - pickle.loads(pickle.dumps(superexc)) - except: - pass - else: - return superexc + yield supercls def create_exception_cls(name, module, parent=None): @@ -165,12 +171,19 @@ def get_pickleable_exception(exc): pass else: return exc - nearest = find_nearest_pickleable_exception(exc) + nearest = find_pickleable_exception(exc) if nearest: return nearest return UnpickleableExceptionWrapper.from_exception(exc) +def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps): + try: + loads(dumps(cls)) + except: + return Exception + + def get_pickled_exception(exc): """Get original exception from exception pickled using :meth:`get_pickleable_exception`.""" diff --git a/awx/lib/site-packages/celery/utils/threads.py b/awx/lib/site-packages/celery/utils/threads.py index 8f88aabf84..9826363491 100644 --- a/awx/lib/site-packages/celery/utils/threads.py +++ b/awx/lib/site-packages/celery/utils/threads.py @@ -18,29 +18,26 @@ from celery.utils.compat import THREAD_TIMEOUT_MAX USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS') PY3 = sys.version_info[0] == 3 +NEW_EVENT = (sys.version_info[0] == 3) and (sys.version_info[1] >= 3) _Thread = threading.Thread -_Event = threading.Event if PY3 else threading._Event +_Event = threading.Event if NEW_EVENT else threading._Event active_count = (getattr(threading, 'active_count', None) or threading.activeCount) -class Event(_Event): +if sys.version_info < (2, 6): - if not hasattr(_Event, 'is_set'): # pragma: no cover + class Event(_Event): # pragma: no cover is_set = _Event.isSet - -class Thread(_Thread): - - if not hasattr(_Thread, 'is_alive'): # pragma: no cover + class Thread(_Thread): # pragma: no cover is_alive = _Thread.isAlive - - if not hasattr(_Thread, 'daemon'): # pragma: no cover daemon = property(_Thread.isDaemon, _Thread.setDaemon) - - if not hasattr(_Thread, 'name'): # pragma: no cover name = property(_Thread.getName, _Thread.setName) +else: + Event = _Event + Thread = _Thread class bgThread(Thread): diff --git a/awx/lib/site-packages/celery/worker/__init__.py b/awx/lib/site-packages/celery/worker/__init__.py index 10faa394de..04f7bcd2e4 100644 --- a/awx/lib/site-packages/celery/worker/__init__.py +++ b/awx/lib/site-packages/celery/worker/__init__.py @@ -19,6 +19,7 @@ import time import traceback from functools import partial +from weakref import WeakValueDictionary from billiard.exceptions import WorkerLostError from billiard.util import Finalize @@ -26,6 +27,7 @@ from kombu.syn import detect_environment from celery import concurrency as _concurrency from celery import platforms +from celery import signals from celery.app import app_or_default from celery.app.abstract import configurated, from_config from celery.exceptions import SystemTerminate, TaskRevokedError @@ -105,6 +107,7 @@ class Pool(bootsteps.StartStopComponent): add_reader = hub.add_reader remove = hub.remove now = time.time + cache = pool._pool._cache # did_start_ok will verify that pool processes were able to start, # but this will only work the first time we start, as @@ -120,25 +123,58 @@ class Pool(bootsteps.StartStopComponent): for handler, interval in pool.timers.iteritems(): hub.timer.apply_interval(interval * 1000.0, handler) - def on_timeout_set(R, soft, hard): + trefs = pool._tref_for_id = WeakValueDictionary() + + def _discard_tref(job): + try: + tref = trefs.pop(job) + tref.cancel() + del(tref) + except (KeyError, AttributeError): + pass # out of scope + + def _on_hard_timeout(job): + try: + result = cache[job] + except KeyError: + pass # job ready + else: + on_hard_timeout(result) + finally: + # remove tref + _discard_tref(job) + + def _on_soft_timeout(job, soft, hard, hub): + if hard: + trefs[job] = apply_at( + now() + (hard - soft), + _on_hard_timeout, (job, ), + ) + try: + result = cache[job] + except KeyError: + pass # job ready + else: + on_soft_timeout(result) + finally: + if not hard: + # remove tref + _discard_tref(job) - def _on_soft_timeout(): - if hard: - R._tref = apply_at(now() + (hard - soft), - on_hard_timeout, (R, )) - on_soft_timeout(R) + def on_timeout_set(R, soft, hard): if soft: - R._tref = apply_after(soft * 1000.0, _on_soft_timeout) + trefs[R._job] = apply_after( + soft * 1000.0, + _on_soft_timeout, (R._job, soft, hard, hub), + ) elif hard: - R._tref = apply_after(hard * 1000.0, - on_hard_timeout, (R, )) + trefs[R._job] = apply_after( + hard * 1000.0, + _on_hard_timeout, (R._job, ) + ) - def on_timeout_cancel(result): - try: - result._tref.cancel() - delattr(result, '_tref') - except AttributeError: - pass + def on_timeout_cancel(R): + _discard_tref(R._job) pool.init_callbacks( on_process_up=lambda w: add_reader(w.sentinel, maintain_pool), @@ -208,19 +244,18 @@ class Queues(bootsteps.Component): def create(self, w): BucketType = TaskBucket - w.start_mediator = not w.disable_rate_limits + w.start_mediator = w.pool_cls.requires_mediator if not w.pool_cls.rlimit_safe: - w.start_mediator = False BucketType = AsyncTaskBucket process_task = w.process_task if w.use_eventloop: - w.start_mediator = False BucketType = AsyncTaskBucket if w.pool_putlocks and w.pool_cls.uses_semaphore: process_task = w.process_task_sem - if w.disable_rate_limits: + if w.disable_rate_limits or not w.start_mediator: w.ready_queue = FastQueue() - w.ready_queue.put = process_task + if not w.start_mediator: + w.ready_queue.put = process_task else: w.ready_queue = BucketType( task_registry=w.app.tasks, callback=process_task, worker=w, @@ -327,7 +362,10 @@ class WorkController(configurated): self.loglevel = loglevel or self.loglevel self.hostname = hostname or socket.gethostname() self.ready_callback = ready_callback - self._finalize = Finalize(self, self.stop, exitpriority=1) + self._finalize = [ + Finalize(self, self.stop, exitpriority=1), + Finalize(self, self._send_worker_shutdown, exitpriority=10), + ] self.pidfile = pidfile self.pidlock = None # this connection is not established, only used for params @@ -350,6 +388,9 @@ class WorkController(configurated): self.components = [] self.namespace = Namespace(app=self.app).apply(self, **kwargs) + def _send_worker_shutdown(self): + signals.worker_shutdown.send(sender=self) + def start(self): """Starts the workers main loop.""" self._state = self.RUN diff --git a/awx/lib/site-packages/celery/worker/buckets.py b/awx/lib/site-packages/celery/worker/buckets.py index e975328d49..2d7ccc3081 100644 --- a/awx/lib/site-packages/celery/worker/buckets.py +++ b/awx/lib/site-packages/celery/worker/buckets.py @@ -39,6 +39,12 @@ class AsyncTaskBucket(object): self.worker = worker self.buckets = {} self.refresh() + self._queue = Queue() + self._quick_put = self._queue.put + self.get = self._queue.get + + def get(self, *args, **kwargs): + return self._queue.get(*args, **kwargs) def cont(self, request, bucket, tokens): if not bucket.can_consume(tokens): @@ -47,7 +53,7 @@ class AsyncTaskBucket(object): hold * 1000.0, self.cont, (request, bucket, tokens), ) else: - self.callback(request) + self._quick_put(request) def put(self, request): name = request.name @@ -56,7 +62,7 @@ class AsyncTaskBucket(object): except KeyError: bucket = self.add_bucket_for_type(name) if not bucket: - return self.callback(request) + return self._quick_put(request) return self.cont(request, bucket, 1) def add_task_type(self, name): diff --git a/awx/lib/site-packages/celery/worker/consumer.py b/awx/lib/site-packages/celery/worker/consumer.py index aed01f3e02..4d811ffed2 100644 --- a/awx/lib/site-packages/celery/worker/consumer.py +++ b/awx/lib/site-packages/celery/worker/consumer.py @@ -80,6 +80,8 @@ import threading from time import sleep from Queue import Empty +from billiard.common import restart_state +from billiard.exceptions import RestartFreqExceeded from kombu.syn import _detect_environment from kombu.utils.encoding import safe_repr, safe_str, bytes_t from kombu.utils.eventio import READ, WRITE, ERR @@ -100,6 +102,13 @@ from .bootsteps import StartStopComponent from .control import Panel from .heartbeat import Heart +try: + buffer_t = buffer +except NameError: # pragma: no cover + + class buffer_t(object): # noqa + pass + RUN = 0x1 CLOSE = 0x2 @@ -171,7 +180,7 @@ def debug(msg, *args, **kwargs): def dump_body(m, body): - if isinstance(body, buffer): + if isinstance(body, buffer_t): body = bytes_t(body) return "%s (%sb)" % (text.truncate(safe_repr(body), 1024), len(m.body)) @@ -348,6 +357,7 @@ class Consumer(object): conninfo = self.app.connection() self.connection_errors = conninfo.connection_errors self.channel_errors = conninfo.channel_errors + self._restart_state = restart_state(maxR=5, maxT=1) self._does_info = logger.isEnabledFor(logging.INFO) self.strategies = {} @@ -391,6 +401,11 @@ class Consumer(object): self.restart_count += 1 self.maybe_shutdown() try: + self._restart_state.step() + except RestartFreqExceeded as exc: + crit('Frequent restarts detected: %r', exc, exc_info=1) + sleep(1) + try: self.reset_connection() self.consume_messages() except self.connection_errors + self.channel_errors: @@ -736,6 +751,7 @@ class Consumer(object): # to the current channel. self.ready_queue.clear() self.timer.clear() + state.reserved_requests.clear() # Re-establish the broker connection and setup the task consumer. self.connection = self._open_connection() diff --git a/awx/lib/site-packages/celery/worker/mediator.py b/awx/lib/site-packages/celery/worker/mediator.py index b467b71c74..0e10392797 100644 --- a/awx/lib/site-packages/celery/worker/mediator.py +++ b/awx/lib/site-packages/celery/worker/mediator.py @@ -36,7 +36,7 @@ class WorkerComponent(StartStopComponent): w.mediator = None def include_if(self, w): - return w.start_mediator and not w.use_eventloop + return w.start_mediator def create(self, w): m = w.mediator = self.instantiate(w.mediator_cls, w.ready_queue, diff --git a/awx/lib/site-packages/django_extensions/__init__.py b/awx/lib/site-packages/django_extensions/__init__.py index 3efe475a76..3426dc3739 100644 --- a/awx/lib/site-packages/django_extensions/__init__.py +++ b/awx/lib/site-packages/django_extensions/__init__.py @@ -1,5 +1,5 @@ -VERSION = (1, 1, 1) +VERSION = (1, 2, 0) # Dynamically calculate the version based on VERSION tuple if len(VERSION) > 2 and VERSION[2] is not None: diff --git a/awx/lib/site-packages/django_extensions/admin/__init__.py b/awx/lib/site-packages/django_extensions/admin/__init__.py index 4df241460b..7564ea65e3 100644 --- a/awx/lib/site-packages/django_extensions/admin/__init__.py +++ b/awx/lib/site-packages/django_extensions/admin/__init__.py @@ -9,6 +9,7 @@ # (Michal Salaban) # +import six import operator from six.moves import reduce from django.http import HttpResponse, HttpResponseNotFound @@ -108,8 +109,7 @@ class ForeignKeyAutocompleteAdmin(ModelAdmin): other_qs.dup_select_related(queryset) other_qs = other_qs.filter(reduce(operator.or_, or_queries)) queryset = queryset & other_qs - data = ''.join([u'%s|%s\n' % ( - to_string_function(f), f.pk) for f in queryset]) + data = ''.join([six.u('%s|%s\n' % (to_string_function(f), f.pk)) for f in queryset]) elif object_pk: try: obj = queryset.get(pk=object_pk) @@ -139,7 +139,7 @@ class ForeignKeyAutocompleteAdmin(ModelAdmin): model_name = db_field.rel.to._meta.object_name help_text = self.get_help_text(db_field.name, model_name) if kwargs.get('help_text'): - help_text = u'%s %s' % (kwargs['help_text'], help_text) + help_text = six.u('%s %s' % (kwargs['help_text'], help_text)) kwargs['widget'] = ForeignKeySearchInput(db_field.rel, self.related_search_fields[db_field.name]) kwargs['help_text'] = help_text return super(ForeignKeyAutocompleteAdmin, self).formfield_for_dbfield(db_field, **kwargs) diff --git a/awx/lib/site-packages/django_extensions/admin/widgets.py b/awx/lib/site-packages/django_extensions/admin/widgets.py index feefd406bc..6a2a4af959 100644 --- a/awx/lib/site-packages/django_extensions/admin/widgets.py +++ b/awx/lib/site-packages/django_extensions/admin/widgets.py @@ -1,3 +1,4 @@ +import six import django from django import forms from django.conf import settings @@ -67,7 +68,7 @@ class ForeignKeySearchInput(ForeignKeyRawIdWidget): if value: label = self.label_for_value(value) else: - label = u'' + label = six.u('') try: admin_media_prefix = settings.ADMIN_MEDIA_PREFIX @@ -92,4 +93,4 @@ class ForeignKeySearchInput(ForeignKeyRawIdWidget): 'django_extensions/widgets/foreignkey_searchinput.html', ), context)) output.reverse() - return mark_safe(u''.join(output)) + return mark_safe(six.u(''.join(output))) diff --git a/awx/lib/site-packages/django_extensions/db/fields/__init__.py b/awx/lib/site-packages/django_extensions/db/fields/__init__.py index 8a17c48064..337ddd2c27 100644 --- a/awx/lib/site-packages/django_extensions/db/fields/__init__.py +++ b/awx/lib/site-packages/django_extensions/db/fields/__init__.py @@ -3,13 +3,13 @@ Django Extensions additional model fields """ import re import six - try: import uuid - assert uuid + HAS_UUID = True except ImportError: - from django_extensions.utils import uuid + HAS_UUID = False +from django.core.exceptions import ImproperlyConfigured from django.template.defaultfilters import slugify from django.db.models import DateTimeField, CharField, SlugField @@ -56,7 +56,7 @@ class AutoSlugField(SlugField): raise ValueError("missing 'populate_from' argument") else: self._populate_from = populate_from - self.separator = kwargs.pop('separator', u'-') + self.separator = kwargs.pop('separator', six.u('-')) self.overwrite = kwargs.pop('overwrite', False) self.allow_duplicates = kwargs.pop('allow_duplicates', False) super(AutoSlugField, self).__init__(*args, **kwargs) @@ -221,13 +221,15 @@ class UUIDVersionError(Exception): class UUIDField(CharField): """ UUIDField - By default uses UUID version 4 (generate from host ID, sequence number and current time) + By default uses UUID version 4 (randomly generated UUID). - The field support all uuid versions which are natively supported by the uuid python module. + The field support all uuid versions which are natively supported by the uuid python module, except version 2. For more information see: http://docs.python.org/lib/module-uuid.html """ - def __init__(self, verbose_name=None, name=None, auto=True, version=1, node=None, clock_seq=None, namespace=None, **kwargs): + def __init__(self, verbose_name=None, name=None, auto=True, version=4, node=None, clock_seq=None, namespace=None, **kwargs): + if not HAS_UUID: + raise ImproperlyConfigured("'uuid' module is required for UUIDField. (Do you have Python 2.5 or higher installed ?)") kwargs.setdefault('max_length', 36) if auto: self.empty_strings_allowed = False @@ -244,17 +246,6 @@ class UUIDField(CharField): def get_internal_type(self): return CharField.__name__ - def contribute_to_class(self, cls, name): - if self.primary_key: - assert not cls._meta.has_auto_field, "A model can't have more than one AutoField: %s %s %s; have %s" % ( - self, cls, name, cls._meta.auto_field - ) - super(UUIDField, self).contribute_to_class(cls, name) - cls._meta.has_auto_field = True - cls._meta.auto_field = self - else: - super(UUIDField, self).contribute_to_class(cls, name) - def create_uuid(self): if not self.version or self.version == 4: return uuid.uuid4() @@ -277,7 +268,7 @@ class UUIDField(CharField): return value else: if self.auto and not value: - value = six.u(self.create_uuid()) + value = force_unicode(self.create_uuid()) setattr(model_instance, self.attname, value) return value diff --git a/awx/lib/site-packages/django_extensions/db/fields/encrypted.py b/awx/lib/site-packages/django_extensions/db/fields/encrypted.py index 5d45b6912a..f897598a96 100644 --- a/awx/lib/site-packages/django_extensions/db/fields/encrypted.py +++ b/awx/lib/site-packages/django_extensions/db/fields/encrypted.py @@ -21,8 +21,11 @@ class BaseEncryptedField(models.Field): def __init__(self, *args, **kwargs): if not hasattr(settings, 'ENCRYPTED_FIELD_KEYS_DIR'): - raise ImproperlyConfigured('You must set the ENCRYPTED_FIELD_KEYS_DIR setting to your Keyczar keys directory.') - self.crypt = keyczar.Crypter.Read(settings.ENCRYPTED_FIELD_KEYS_DIR) + raise ImproperlyConfigured('You must set the ENCRYPTED_FIELD_KEYS_DIR ' + 'setting to your Keyczar keys directory.') + + crypt_class = self.get_crypt_class() + self.crypt = crypt_class.Read(settings.ENCRYPTED_FIELD_KEYS_DIR) # Encrypted size is larger than unencrypted self.unencrypted_length = max_length = kwargs.get('max_length', None) @@ -34,6 +37,32 @@ class BaseEncryptedField(models.Field): super(BaseEncryptedField, self).__init__(*args, **kwargs) + def get_crypt_class(self): + """ + Get the Keyczar class to use. + + The class can be customized with the ENCRYPTED_FIELD_MODE setting. By default, + this setting is DECRYPT_AND_ENCRYPT. Set this to ENCRYPT to disable decryption. + This is necessary if you are only providing public keys to Keyczar. + + Returns: + keyczar.Encrypter if ENCRYPTED_FIELD_MODE is ENCRYPT. + keyczar.Crypter if ENCRYPTED_FIELD_MODE is DECRYPT_AND_ENCRYPT. + + Override this method to customize the type of Keyczar class returned. + """ + + crypt_type = getattr(settings, 'ENCRYPTED_FIELD_MODE', 'DECRYPT_AND_ENCRYPT') + if crypt_type == 'ENCRYPT': + crypt_class_name = 'Encrypter' + elif crypt_type == 'DECRYPT_AND_ENCRYPT': + crypt_class_name = 'Crypter' + else: + raise ImproperlyConfigured( + 'ENCRYPTED_FIELD_MODE must be either DECRYPT_AND_ENCRYPT ' + 'or ENCRYPT, not %s.' % crypt_type) + return getattr(keyczar, crypt_class_name) + def to_python(self, value): if isinstance(self.crypt.primary_key, keyczar.keys.RsaPublicKey): retval = value @@ -64,9 +93,8 @@ class BaseEncryptedField(models.Field): return value -class EncryptedTextField(BaseEncryptedField): - __metaclass__ = models.SubfieldBase - +class EncryptedTextField(six.with_metaclass(models.SubfieldBase, + BaseEncryptedField)): def get_internal_type(self): return 'TextField' @@ -85,9 +113,8 @@ class EncryptedTextField(BaseEncryptedField): return (field_class, args, kwargs) -class EncryptedCharField(BaseEncryptedField): - __metaclass__ = models.SubfieldBase - +class EncryptedCharField(six.with_metaclass(models.SubfieldBase, + BaseEncryptedField)): def __init__(self, *args, **kwargs): super(EncryptedCharField, self).__init__(*args, **kwargs) @@ -107,4 +134,3 @@ class EncryptedCharField(BaseEncryptedField): args, kwargs = introspector(self) # That's our definition! return (field_class, args, kwargs) - diff --git a/awx/lib/site-packages/django_extensions/db/fields/json.py b/awx/lib/site-packages/django_extensions/db/fields/json.py index e1b1b51602..51d5b1dd52 100644 --- a/awx/lib/site-packages/django_extensions/db/fields/json.py +++ b/awx/lib/site-packages/django_extensions/db/fields/json.py @@ -58,16 +58,13 @@ class JSONList(list): return dumps(self) -class JSONField(models.TextField): +class JSONField(six.with_metaclass(models.SubfieldBase, models.TextField)): """JSONField is a generic textfield that neatly serializes/unserializes JSON objects seamlessly. Main thingy must be a dict object.""" - # Used so to_python() is called - __metaclass__ = models.SubfieldBase - def __init__(self, *args, **kwargs): - default = kwargs.get('default') - if not default: + default = kwargs.get('default', None) + if default is None: kwargs['default'] = '{}' elif isinstance(default, (list, dict)): kwargs['default'] = dumps(default) diff --git a/awx/lib/site-packages/django_extensions/future_1_5.py b/awx/lib/site-packages/django_extensions/future_1_5.py new file mode 100644 index 0000000000..d7144ca383 --- /dev/null +++ b/awx/lib/site-packages/django_extensions/future_1_5.py @@ -0,0 +1,16 @@ +""" +A forwards compatibility module. + +Implements some features of Django 1.5 related to the 'Custom User Model' feature +when the application is run with a lower version of Django. +""" +from __future__ import unicode_literals + +from django.contrib.auth.models import User + +User.USERNAME_FIELD = "username" +User.get_username = lambda self: self.username + + +def get_user_model(): + return User diff --git a/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py b/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py index 299a0201fe..5caf40efa3 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py +++ b/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py @@ -71,9 +71,9 @@ def orm_item_locator(orm_obj): for key in clean_dict: v = clean_dict[key] if v is not None and not isinstance(v, (six.string_types, six.integer_types, float, datetime.datetime)): - clean_dict[key] = u"%s" % v + clean_dict[key] = six.u("%s" % v) - output = """ locate_object(%s, "%s", %s, "%s", %s, %s ) """ % ( + output = """ importer.locate_object(%s, "%s", %s, "%s", %s, %s ) """ % ( original_class, original_pk_name, the_class, pk_name, pk_value, clean_dict ) @@ -264,7 +264,7 @@ class InstanceCode(Code): # Print the save command for our new object # e.g. model_name_35.save() if code_lines: - code_lines.append("%s = save_or_locate(%s)\n" % (self.variable_name, self.variable_name)) + code_lines.append("%s = importer.save_or_locate(%s)\n" % (self.variable_name, self.variable_name)) code_lines += self.get_many_to_many_lines(force=force) @@ -499,7 +499,6 @@ class Script(Code): code.insert(2, "") for key, value in self.context["__extra_imports"].items(): code.insert(2, " from %s import %s" % (value, key)) - code.insert(2 + len(self.context["__extra_imports"]), self.locate_object_function) return code @@ -513,11 +512,17 @@ class Script(Code): # This file has been automatically generated. # Instead of changing it, create a file called import_helper.py -# which this script has hooks to. +# and put there a class called ImportHelper(object) in it. # -# On that file, don't forget to add the necessary Django imports -# and take a look at how locate_object() and save_or_locate() -# are implemented here and expected to behave. +# This class will be specially casted so that instead of extending object, +# it will actually extend the class BasicImportHelper() +# +# That means you just have to overload the methods you want to +# change, leaving the other ones inteact. +# +# Something that you might want to do is use transactions, for example. +# +# Also, don't forget to add the necessary Django imports. # # This file was generated with the following command: # %s @@ -530,24 +535,31 @@ class Script(Code): # you must make sure ./some_folder/__init__.py exists # and run ./manage.py runscript some_folder.some_script +from django.db import transaction -IMPORT_HELPER_AVAILABLE = False -try: - import import_helper - IMPORT_HELPER_AVAILABLE = True -except ImportError: - pass +class BasicImportHelper(object): -import datetime -from decimal import Decimal -from django.contrib.contenttypes.models import ContentType + def pre_import(self): + pass -def run(): + # You probably want to uncomment on of these two lines + # @transaction.atomic # Django 1.6 + # @transaction.commit_on_success # Django <1.6 + def run_import(self, import_data): + import_data() -""" % " ".join(sys.argv) + def post_import(self): + pass + + def locate_similar(self, current_object, search_data): + #you will probably want to call this method from save_or_locate() + #example: + #new_obj = self.locate_similar(the_obj, {"national_id": the_obj.national_id } ) - locate_object_function = """ - def locate_object(original_class, original_pk_name, the_class, pk_name, pk_value, obj_content): + the_obj = current_object.__class__.objects.get(**search_data) + return the_obj + + def locate_object(self, original_class, original_pk_name, the_class, pk_name, pk_value, obj_content): #You may change this function to do specific lookup for specific objects # #original_class class of the django orm's object that needs to be located @@ -571,22 +583,55 @@ def run(): #if the_class == StaffGroup: # pk_value=8 - - if IMPORT_HELPER_AVAILABLE and hasattr(import_helper, "locate_object"): - return import_helper.locate_object(original_class, original_pk_name, the_class, pk_name, pk_value, obj_content) - search_data = { pk_name: pk_value } - the_obj =the_class.objects.get(**search_data) + the_obj = the_class.objects.get(**search_data) + #print(the_obj) return the_obj - def save_or_locate(the_obj): - if IMPORT_HELPER_AVAILABLE and hasattr(import_helper, "save_or_locate"): - the_obj = import_helper.save_or_locate(the_obj) - else: + + def save_or_locate(self, the_obj): + #change this if you want to locate the object in the database + try: the_obj.save() + except: + print("---------------") + print("Error saving the following object:") + print(the_obj.__class__) + print(" ") + print(the_obj.__dict__) + print(" ") + print(the_obj) + print(" ") + print("---------------") + + raise return the_obj -""" + +importer = None +try: + import import_helper + #we need this so ImportHelper can extend BasicImportHelper, although import_helper.py + #has no knowlodge of this class + importer = type("DynamicImportHelper", (import_helper.ImportHelper, BasicImportHelper ) , {} )() +except ImportError as e: + if str(e) == "No module named import_helper": + importer = BasicImportHelper() + else: + raise + +import datetime +from decimal import Decimal +from django.contrib.contenttypes.models import ContentType + +def run(): + importer.pre_import() + importer.run_import(import_data) + importer.post_import() + +def import_data(): + +""" % " ".join(sys.argv) # HELPER FUNCTIONS diff --git a/awx/lib/site-packages/django_extensions/management/commands/export_emails.py b/awx/lib/site-packages/django_extensions/management/commands/export_emails.py index 7766a1b6e2..05a8689b05 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/export_emails.py +++ b/awx/lib/site-packages/django_extensions/management/commands/export_emails.py @@ -1,8 +1,13 @@ from django.core.management.base import BaseCommand, CommandError -from django.contrib.auth.models import User, Group +try: + from django.contrib.auth import get_user_model # Django 1.5 +except ImportError: + from django_extensions.future_1_5 import get_user_model +from django.contrib.auth.models import Group from optparse import make_option from sys import stdout from csv import writer +import six FORMATS = [ 'address', @@ -15,7 +20,7 @@ FORMATS = [ def full_name(first_name, last_name, username, **extra): - name = u" ".join(n for n in [first_name, last_name] if n) + name = six.u(" ").join(n for n in [first_name, last_name] if n) if not name: return username return name @@ -42,7 +47,7 @@ class Command(BaseCommand): raise CommandError("extra arguments supplied") group = options['group'] if group and not Group.objects.filter(name=group).count() == 1: - names = u"', '".join(g['name'] for g in Group.objects.values('name')).encode('utf-8') + names = six.u("', '").join(g['name'] for g in Group.objects.values('name')).encode('utf-8') if names: names = "'" + names + "'." raise CommandError("Unknown group '" + group + "'. Valid group names are: " + names) @@ -51,6 +56,7 @@ class Command(BaseCommand): else: outfile = stdout + User = get_user_model() qs = User.objects.all().order_by('last_name', 'first_name', 'username', 'email') if group: qs = qs.filter(group__name=group).distinct() @@ -61,15 +67,15 @@ class Command(BaseCommand): """simple single entry per line in the format of: "full name" <my@address.com>; """ - out.write(u"\n".join(u'"%s" <%s>;' % (full_name(**ent), ent['email']) - for ent in qs).encode(self.encoding)) + out.write(six.u("\n").join(six.u('"%s" <%s>;' % (full_name(**ent), ent['email'])) + for ent in qs).encode(self.encoding)) out.write("\n") def emails(self, qs, out): """simpler single entry with email only in the format of: my@address.com, """ - out.write(u",\n".join(u'%s' % (ent['email']) for ent in qs).encode(self.encoding)) + out.write(six.u(",\n").join(six.u('%s' % (ent['email'])) for ent in qs).encode(self.encoding)) out.write("\n") def google(self, qs, out): diff --git a/awx/lib/site-packages/django_extensions/management/commands/graph_models.py b/awx/lib/site-packages/django_extensions/management/commands/graph_models.py index 4ff03bb3aa..59403e248c 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/graph_models.py +++ b/awx/lib/site-packages/django_extensions/management/commands/graph_models.py @@ -56,7 +56,7 @@ class Command(BaseCommand): vizdata = ' '.join(dotdata.split("\n")).strip().encode('utf-8') version = pygraphviz.__version__.rstrip("-svn") try: - if [int(v) for v in version.split('.')] < (0, 36): + if tuple(int(v) for v in version.split('.')) < (0, 36): # HACK around old/broken AGraph before version 0.36 (ubuntu ships with this old version) import tempfile tmpfile = tempfile.NamedTemporaryFile() diff --git a/awx/lib/site-packages/django_extensions/management/commands/passwd.py b/awx/lib/site-packages/django_extensions/management/commands/passwd.py index 1bd58ab524..6f084b143d 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/passwd.py +++ b/awx/lib/site-packages/django_extensions/management/commands/passwd.py @@ -1,5 +1,8 @@ from django.core.management.base import BaseCommand, CommandError -from django.contrib.auth.models import User +try: + from django.contrib.auth import get_user_model # Django 1.5 +except ImportError: + from django_extensions.future_1_5 import get_user_model import getpass @@ -17,6 +20,7 @@ class Command(BaseCommand): else: username = getpass.getuser() + User = get_user_model() try: u = User.objects.get(username=username) except User.DoesNotExist: diff --git a/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py b/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py index 099ef86050..0cc7b3e27d 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py +++ b/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py @@ -8,6 +8,12 @@ import urlparse import xmlrpclib from distutils.version import LooseVersion +try: + import requests +except ImportError: + print("""The requests library is not installed. To continue: + pip install requests""") + from optparse import make_option from django.core.management.base import NoArgsCommand @@ -166,13 +172,25 @@ class Command(NoArgsCommand): } if self.github_api_token: headers["Authorization"] = "token {0}".format(self.github_api_token) - user, repo = urlparse.urlparse(req_url).path.split("#")[0].strip("/").rstrip("/").split("/") + try: + user, repo = urlparse.urlparse(req_url).path.split("#")[0].strip("/").rstrip("/").split("/") + except (ValueError, IndexError) as e: + print("\nFailed to parse %r: %s\n" % (req_url, e)) + continue + + try: + #test_auth = self._urlopen_as_json("https://api.github.com/django/", headers=headers) + test_auth = requests.get("https://api.github.com/django/", headers=headers).json() + except urllib2.HTTPError as e: + print("\n%s\n" % str(e)) + return - test_auth = self._urlopen_as_json("https://api.github.com/django/", headers=headers) if "message" in test_auth and test_auth["message"] == "Bad credentials": - sys.exit("\nGithub API: Bad credentials. Aborting!\n") + print("\nGithub API: Bad credentials. Aborting!\n") + return elif "message" in test_auth and test_auth["message"].startswith("API Rate Limit Exceeded"): - sys.exit("\nGithub API: Rate Limit Exceeded. Aborting!\n") + print("\nGithub API: Rate Limit Exceeded. Aborting!\n") + return if ".git" in repo: repo_name, frozen_commit_full = repo.split(".git") @@ -186,11 +204,14 @@ class Command(NoArgsCommand): if frozen_commit_sha: branch_url = "https://api.github.com/repos/{0}/{1}/branches".format(user, repo_name) - branch_data = self._urlopen_as_json(branch_url, headers=headers) - - frozen_commit_url = "https://api.github.com/repos/{0}/{1}/commits/{2}" \ - .format(user, repo_name, frozen_commit_sha) - frozen_commit_data = self._urlopen_as_json(frozen_commit_url, headers=headers) + #branch_data = self._urlopen_as_json(branch_url, headers=headers) + branch_data = requests.get(branch_url, headers=headers).json() + + frozen_commit_url = "https://api.github.com/repos/{0}/{1}/commits/{2}".format( + user, repo_name, frozen_commit_sha + ) + #frozen_commit_data = self._urlopen_as_json(frozen_commit_url, headers=headers) + frozen_commit_data = requests.get(frozen_commit_url, headers=headers).json() if "message" in frozen_commit_data and frozen_commit_data["message"] == "Not Found": msg = "{0} not found in {1}. Repo may be private.".format(frozen_commit_sha[:10], name) diff --git a/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py b/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py index 695c050bd8..000453aa81 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py +++ b/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py @@ -1,5 +1,8 @@ from django.core.management.base import BaseCommand, CommandError -from django.contrib.auth.models import User +try: + from django.contrib.auth import get_user_model # Django 1.5 +except ImportError: + from django_extensions.future_1_5 import get_user_model from django.contrib.sessions.models import Session import re @@ -38,6 +41,7 @@ class Command(BaseCommand): print('No user associated with session') return print("User id: %s" % uid) + User = get_user_model() try: user = User.objects.get(pk=uid) except User.DoesNotExist: diff --git a/awx/lib/site-packages/django_extensions/management/commands/reset_db.py b/awx/lib/site-packages/django_extensions/management/commands/reset_db.py index 71d59818f6..e3f3fad3b6 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/reset_db.py +++ b/awx/lib/site-packages/django_extensions/management/commands/reset_db.py @@ -92,10 +92,10 @@ Type 'yes' to continue, or 'no' to cancel: """ % (settings.DATABASE_NAME,)) if password is None: password = settings.DATABASE_PASSWORD - if engine == 'sqlite3': + if engine in ('sqlite3', 'spatialite'): import os try: - logging.info("Unlinking sqlite3 database") + logging.info("Unlinking %s database" % engine) os.unlink(settings.DATABASE_NAME) except OSError: pass diff --git a/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py b/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py index 47cd87b917..7798a284e6 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py +++ b/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py @@ -116,6 +116,8 @@ class Command(BaseCommand): help='Specifies the directory from which to serve admin media.'), make_option('--prof-path', dest='prof_path', default='/tmp', help='Specifies the directory which to save profile information in.'), + make_option('--prof-file', dest='prof_file', default='{path}.{duration:06d}ms.{time}', + help='Set filename format, default if "{path}.{duration:06d}ms.{time}".'), make_option('--nomedia', action='store_true', dest='no_media', default=False, help='Do not profile MEDIA_URL and ADMIN_MEDIA_URL'), make_option('--use-cprofile', action='store_true', dest='use_cprofile', default=False, @@ -186,6 +188,11 @@ class Command(BaseCommand): raise SystemExit("Kcachegrind compatible output format required cProfile from Python 2.5") prof_path = options.get('prof_path', '/tmp') + prof_file = options.get('prof_file', '{path}.{duration:06d}ms.{time}') + if not prof_file.format(path='1', duration=2, time=3): + prof_file = '{path}.{duration:06d}ms.{time}' + print("Filename format is wrong. Default format used: '{path}.{duration:06d}ms.{time}'.") + def get_exclude_paths(): exclude_paths = [] media_url = getattr(settings, 'MEDIA_URL', None) @@ -225,8 +232,8 @@ class Command(BaseCommand): kg.output(open(profname, 'w')) elif USE_CPROFILE: prof.dump_stats(profname) - profname2 = "%s.%06dms.%d.prof" % (path_name, elapms, time.time()) - profname2 = os.path.join(prof_path, profname2) + profname2 = prof_file.format(path=path_name, duration=int(elapms), time=int(time.time())) + profname2 = os.path.join(prof_path, "%s.prof" % profname2) if not USE_CPROFILE: prof.close() os.rename(profname, profname2) @@ -278,4 +285,3 @@ class Command(BaseCommand): autoreload.main(inner_run) else: inner_run() - diff --git a/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py b/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py index 8899a6b687..678e8eb91b 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py +++ b/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py @@ -3,15 +3,31 @@ from django.core.management.base import BaseCommand, CommandError from django_extensions.management.utils import setup_logger, RedirectHandler from optparse import make_option import os +import re +import socket import sys import time try: - from django.contrib.staticfiles.handlers import StaticFilesHandler - USE_STATICFILES = 'django.contrib.staticfiles' in settings.INSTALLED_APPS + if 'django.contrib.staticfiles' in settings.INSTALLED_APPS: + from django.contrib.staticfiles.handlers import StaticFilesHandler + USE_STATICFILES = True + elif 'staticfiles' in settings.INSTALLED_APPS: + from staticfiles.handlers import StaticFilesHandler # noqa + USE_STATICFILES = True + else: + USE_STATICFILES = False except ImportError: USE_STATICFILES = False +naiveip_re = re.compile(r"""^(?: +(?P<addr> + (?P<ipv4>\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address + (?P<ipv6>\[[a-fA-F0-9:]+\]) | # IPv6 address + (?P<fqdn>[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN +):)?(?P<port>\d+)$""", re.X) +DEFAULT_PORT = "8000" + import logging logger = logging.getLogger(__name__) @@ -20,6 +36,8 @@ from django_extensions.management.technical_response import null_technical_500_r class Command(BaseCommand): option_list = BaseCommand.option_list + ( + make_option('--ipv6', '-6', action='store_true', dest='use_ipv6', default=False, + help='Tells Django to use a IPv6 address.'), make_option('--noreload', action='store_false', dest='use_reloader', default=True, help='Tells Django to NOT use the auto-reloader.'), make_option('--browser', action='store_true', dest='open_browser', @@ -103,33 +121,51 @@ class Command(BaseCommand): from django.views import debug debug.technical_500_response = null_technical_500_response - if args: - raise CommandError('Usage is runserver %s' % self.args) + self.use_ipv6 = options.get('use_ipv6') + if self.use_ipv6 and not socket.has_ipv6: + raise CommandError('Your Python does not support IPv6.') + self._raw_ipv6 = False if not addrport: - addr = '' - port = '8000' - else: try: - addr, port = addrport.split(':') - except ValueError: - addr, port = '', addrport - if not addr: - addr = '127.0.0.1' - - if not port.isdigit(): - raise CommandError("%r is not a valid port number." % port) + addrport = settings.RUNSERVERPLUS_SERVER_ADDRESS_PORT + except AttributeError: + pass + if not addrport: + self.addr = '' + self.port = DEFAULT_PORT + else: + m = re.match(naiveip_re, addrport) + if m is None: + raise CommandError('"%s" is not a valid port number ' + 'or address:port pair.' % addrport) + self.addr, _ipv4, _ipv6, _fqdn, self.port = m.groups() + if not self.port.isdigit(): + raise CommandError("%r is not a valid port number." % + self.port) + if self.addr: + if _ipv6: + self.addr = self.addr[1:-1] + self.use_ipv6 = True + self._raw_ipv6 = True + elif self.use_ipv6 and not _fqdn: + raise CommandError('"%s" is not a valid IPv6 address.' + % self.addr) + if not self.addr: + self.addr = '::1' if self.use_ipv6 else '127.0.0.1' threaded = options.get('threaded', False) use_reloader = options.get('use_reloader', True) open_browser = options.get('open_browser', False) cert_path = options.get("cert_path") quit_command = (sys.platform == 'win32') and 'CTRL-BREAK' or 'CONTROL-C' + bind_url = "http://%s:%s/" % ( + self.addr if not self._raw_ipv6 else '[%s]' % self.addr, self.port) def inner_run(): print("Validating models...") self.validate(display_num_errors=True) print("\nDjango version %s, using settings %r" % (django.get_version(), settings.SETTINGS_MODULE)) - print("Development server is running at http://%s:%s/" % (addr, port)) + print("Development server is running at %s" % (bind_url,)) print("Using the Werkzeug debugger (http://werkzeug.pocoo.org/)") print("Quit the server with %s." % quit_command) path = options.get('admin_media_path', '') @@ -149,8 +185,7 @@ class Command(BaseCommand): handler = StaticFilesHandler(handler) if open_browser: import webbrowser - url = "http://%s:%s/" % (addr, port) - webbrowser.open(url) + webbrowser.open(bind_url) if cert_path: """ OpenSSL is needed for SSL support. @@ -189,8 +224,8 @@ class Command(BaseCommand): else: ssl_context = None run_simple( - addr, - int(port), + self.addr, + int(self.port), DebuggedApplication(handler, True), use_reloader=use_reloader, use_debugger=True, diff --git a/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py b/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py index df1e387a69..effeb476e2 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py +++ b/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py @@ -38,7 +38,11 @@ class Command(NoArgsCommand): if not settings.DEBUG: raise CommandError('Only available in debug mode') - from django.contrib.auth.models import User, Group + try: + from django.contrib.auth import get_user_model # Django 1.5 + except ImportError: + from django_extensions.future_1_5 import get_user_model + from django.contrib.auth.models import Group email = options.get('default_email', DEFAULT_FAKE_EMAIL) include_regexp = options.get('include_regexp', None) exclude_regexp = options.get('exclude_regexp', None) @@ -47,6 +51,7 @@ class Command(NoArgsCommand): no_admin = options.get('no_admin', False) no_staff = options.get('no_staff', False) + User = get_user_model() users = User.objects.all() if no_admin: users = users.exclude(is_superuser=True) diff --git a/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py b/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py index 3ce2437548..e502fbfdb4 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py +++ b/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py @@ -28,7 +28,11 @@ class Command(NoArgsCommand): if not settings.DEBUG: raise CommandError('Only available in debug mode') - from django.contrib.auth.models import User + try: + from django.contrib.auth import get_user_model # Django 1.5 + except ImportError: + from django_extensions.future_1_5 import get_user_model + if options.get('prompt_passwd', False): from getpass import getpass passwd = getpass('Password: ') @@ -37,6 +41,7 @@ class Command(NoArgsCommand): else: passwd = options.get('default_passwd', DEFAULT_FAKE_PASSWORD) + User = get_user_model() user = User() user.set_password(passwd) count = User.objects.all().update(password=user.password) diff --git a/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py b/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py index 19527ecdbe..8eacf127a6 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py +++ b/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py @@ -1,5 +1,6 @@ from optparse import make_option import sys +import socket import django from django.core.management.base import CommandError, BaseCommand @@ -57,6 +58,7 @@ The envisioned use case is something like this: dbuser = settings.DATABASE_USER dbpass = settings.DATABASE_PASSWORD dbhost = settings.DATABASE_HOST + dbclient = socket.gethostname() # django settings file tells you that localhost should be specified by leaving # the DATABASE_HOST blank @@ -69,7 +71,7 @@ The envisioned use case is something like this: """) print("CREATE DATABASE %s CHARACTER SET utf8 COLLATE utf8_bin;" % dbname) print("GRANT ALL PRIVILEGES ON %s.* to '%s'@'%s' identified by '%s';" % ( - dbname, dbuser, dbhost, dbpass + dbname, dbuser, dbclient, dbpass )) elif engine == 'postgresql_psycopg2': if options.get('drop'): diff --git a/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py b/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py index b2049a12c7..d89d218eef 100644 --- a/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py +++ b/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py @@ -228,7 +228,7 @@ class SQLDiff(object): def strip_parameters(self, field_type): if field_type and field_type != 'double precision': - return field_type.split(" ")[0].split("(")[0] + return field_type.split(" ")[0].split("(")[0].lower() return field_type def find_unique_missing_in_db(self, meta, table_indexes, table_name): @@ -289,14 +289,14 @@ class SQLDiff(object): continue description = db_fields[field.name] - model_type = self.strip_parameters(self.get_field_model_type(field)) - db_type = self.strip_parameters(self.get_field_db_type(description, field)) + model_type = self.get_field_model_type(field) + db_type = self.get_field_db_type(description, field) # use callback function if defined if func: model_type, db_type = func(field, description, model_type, db_type) - if not model_type == db_type: + if not self.strip_parameters(db_type) == self.strip_parameters(model_type): self.add_difference('field-type-differ', table_name, field.name, model_type, db_type) def find_field_parameter_differ(self, meta, table_description, table_name, func=None): diff --git a/awx/lib/site-packages/django_extensions/management/modelviz.py b/awx/lib/site-packages/django_extensions/management/modelviz.py index 83ac484798..1d18d16972 100644 --- a/awx/lib/site-packages/django_extensions/management/modelviz.py +++ b/awx/lib/site-packages/django_extensions/management/modelviz.py @@ -210,9 +210,9 @@ def generate_dot(app_labels, **kwargs): if skip_field(field): continue if isinstance(field, OneToOneField): - add_relation(field, '[arrowhead=none, arrowtail=none]') + add_relation(field, '[arrowhead=none, arrowtail=none, dir=both]') elif isinstance(field, ForeignKey): - add_relation(field, '[arrowhead=none, arrowtail=dot]') + add_relation(field, '[arrowhead=none, arrowtail=dot, dir=both]') for field in appmodel._meta.local_many_to_many: if skip_field(field): @@ -240,7 +240,7 @@ def generate_dot(app_labels, **kwargs): 'type': "inheritance", 'name': "inheritance", 'label': l, - 'arrows': '[arrowhead=empty, arrowtail=none]', + 'arrows': '[arrowhead=empty, arrowtail=none, dir=both]', 'needs_node': True } # TODO: seems as if abstract models aren't part of models.getModels, which is why they are printed by this without any attributes. diff --git a/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py b/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py index 4ffcc3ec17..d36ddb8405 100644 --- a/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py +++ b/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py @@ -69,7 +69,7 @@ class AutoSlugField(SlugField): raise ValueError("missing 'populate_from' argument") else: self._populate_from = populate_from - self.separator = kwargs.pop('separator', u'-') + self.separator = kwargs.pop('separator', six.u('-')) self.overwrite = kwargs.pop('overwrite', False) super(AutoSlugField, self).__init__(*args, **kwargs) diff --git a/awx/lib/site-packages/django_extensions/templatetags/widont.py b/awx/lib/site-packages/django_extensions/templatetags/widont.py index 687e5114ee..d42833f941 100644 --- a/awx/lib/site-packages/django_extensions/templatetags/widont.py +++ b/awx/lib/site-packages/django_extensions/templatetags/widont.py @@ -1,6 +1,7 @@ from django.template import Library from django.utils.encoding import force_unicode import re +import six register = Library() re_widont = re.compile(r'\s+(\S+\s*)$') @@ -24,7 +25,7 @@ def widont(value, count=1): NoEffect """ def replace(matchobj): - return u' %s' % matchobj.group(1) + return six.u(' %s' % matchobj.group(1)) for i in range(count): value = re_widont.sub(replace, force_unicode(value)) return value @@ -48,7 +49,7 @@ def widont_html(value): leading text <p>test me out</p> trailing text """ def replace(matchobj): - return u'%s %s%s' % matchobj.groups() + return six.u('%s %s%s' % matchobj.groups()) return re_widont_html.sub(replace, force_unicode(value)) register.filter(widont) diff --git a/awx/lib/site-packages/django_extensions/tests/json_field.py b/awx/lib/site-packages/django_extensions/tests/json_field.py index a73d1f6706..e9aed0ffc0 100644 --- a/awx/lib/site-packages/django_extensions/tests/json_field.py +++ b/awx/lib/site-packages/django_extensions/tests/json_field.py @@ -25,9 +25,13 @@ class JsonFieldTest(unittest.TestCase): def testCharFieldCreate(self): j = TestModel.objects.create(a=6, j_field=dict(foo='bar')) - self.assertEquals(j.a, 6) + self.assertEqual(j.a, 6) + + def testDefault(self): + j = TestModel.objects.create(a=1) + self.assertEqual(j.j_field, {}) def testEmptyList(self): j = TestModel.objects.create(a=6, j_field=[]) self.assertTrue(isinstance(j.j_field, list)) - self.assertEquals(j.j_field, []) + self.assertEqual(j.j_field, []) diff --git a/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py b/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py index b25dae0537..dd4b2190b2 100644 --- a/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py +++ b/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py @@ -52,7 +52,7 @@ class DumpScriptTests(TestCase): tmp_out = StringIO() call_command('dumpscript', 'tests', stdout=tmp_out) self.assertTrue('Mike' in tmp_out.getvalue()) # script should go to tmp_out - self.assertEquals(0, len(sys.stdout.getvalue())) # there should not be any output to sys.stdout + self.assertEqual(0, len(sys.stdout.getvalue())) # there should not be any output to sys.stdout tmp_out.close() #---------------------------------------------------------------------- @@ -65,7 +65,7 @@ class DumpScriptTests(TestCase): call_command('dumpscript', 'tests', stderr=tmp_err) self.assertTrue('Fred' in sys.stdout.getvalue()) # script should still go to stdout self.assertTrue('Name' in tmp_err.getvalue()) # error output should go to tmp_err - self.assertEquals(0, len(sys.stderr.getvalue())) # there should not be any output to sys.stderr + self.assertEqual(0, len(sys.stderr.getvalue())) # there should not be any output to sys.stderr tmp_err.close() #---------------------------------------------------------------------- diff --git a/awx/lib/site-packages/django_extensions/tests/utils.py b/awx/lib/site-packages/django_extensions/tests/utils.py index c91989f9f4..23935b66e0 100644 --- a/awx/lib/site-packages/django_extensions/tests/utils.py +++ b/awx/lib/site-packages/django_extensions/tests/utils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import sys +import six from django.test import TestCase from django.utils.unittest import skipIf @@ -14,21 +15,21 @@ except ImportError: class TruncateLetterTests(TestCase): def test_truncate_more_than_text_length(self): - self.assertEquals(u"hello tests", truncate_letters("hello tests", 100)) + self.assertEqual(six.u("hello tests"), truncate_letters("hello tests", 100)) def test_truncate_text(self): - self.assertEquals(u"hello...", truncate_letters("hello tests", 5)) + self.assertEqual(six.u("hello..."), truncate_letters("hello tests", 5)) def test_truncate_with_range(self): for i in range(10, -1, -1): self.assertEqual( - u'hello tests'[:i] + '...', + six.u('hello tests'[:i]) + '...', truncate_letters("hello tests", i) ) def test_with_non_ascii_characters(self): - self.assertEquals( - u'\u5ce0 (\u3068\u3046\u3052 t\u014dg...', + self.assertEqual( + six.u('\u5ce0 (\u3068\u3046\u3052 t\u014dg...'), truncate_letters("峠 (とうげ tōge - mountain pass)", 10) ) @@ -37,7 +38,7 @@ class UUIDTests(TestCase): @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') def test_uuid3(self): # make a UUID using an MD5 hash of a namespace UUID and a name - self.assertEquals( + self.assertEqual( uuid.UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e'), uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org') ) @@ -45,7 +46,7 @@ class UUIDTests(TestCase): @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') def test_uuid5(self): # make a UUID using a SHA-1 hash of a namespace UUID and a name - self.assertEquals( + self.assertEqual( uuid.UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d'), uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org') ) @@ -55,21 +56,21 @@ class UUIDTests(TestCase): # make a UUID from a string of hex digits (braces and hyphens ignored) x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}') # convert a UUID to a string of hex digits in standard form - self.assertEquals('00010203-0405-0607-0809-0a0b0c0d0e0f', str(x)) + self.assertEqual('00010203-0405-0607-0809-0a0b0c0d0e0f', str(x)) @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') def test_uuid_bytes(self): # make a UUID from a string of hex digits (braces and hyphens ignored) x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}') # get the raw 16 bytes of the UUID - self.assertEquals( + self.assertEqual( '\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f', x.bytes ) @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') def test_make_uuid_from_byte_string(self): - self.assertEquals( + self.assertEqual( uuid.UUID(bytes='\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f'), uuid.UUID('00010203-0405-0607-0809-0a0b0c0d0e0f') ) diff --git a/awx/lib/site-packages/django_extensions/tests/uuid_field.py b/awx/lib/site-packages/django_extensions/tests/uuid_field.py index 43416e5734..823ff2e2ba 100644 --- a/awx/lib/site-packages/django_extensions/tests/uuid_field.py +++ b/awx/lib/site-packages/django_extensions/tests/uuid_field.py @@ -1,3 +1,4 @@ +import six from django.conf import settings from django.core.management import call_command from django.db.models import loading @@ -36,20 +37,22 @@ class UUIDFieldTest(unittest.TestCase): settings.INSTALLED_APPS = self.old_installed_apps def testUUIDFieldCreate(self): - j = TestModel_field.objects.create(a=6, uuid_field=u'550e8400-e29b-41d4-a716-446655440000') - self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440000') + j = TestModel_field.objects.create(a=6, uuid_field=six.u('550e8400-e29b-41d4-a716-446655440000')) + self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440000')) def testUUIDField_pkCreate(self): - j = TestModel_pk.objects.create(uuid_field=u'550e8400-e29b-41d4-a716-446655440000') - self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440000') - self.assertEquals(j.pk, u'550e8400-e29b-41d4-a716-446655440000') + j = TestModel_pk.objects.create(uuid_field=six.u('550e8400-e29b-41d4-a716-446655440000')) + self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440000')) + self.assertEqual(j.pk, six.u('550e8400-e29b-41d4-a716-446655440000')) def testUUIDField_pkAgregateCreate(self): - j = TestAgregateModel.objects.create(a=6) - self.assertEquals(j.a, 6) + j = TestAgregateModel.objects.create(a=6, uuid_field=six.u('550e8400-e29b-41d4-a716-446655440001')) + self.assertEqual(j.a, 6) + self.assertIsInstance(j.pk, six.string_types) + self.assertEqual(len(j.pk), 36) def testUUIDFieldManyToManyCreate(self): - j = TestManyToManyModel.objects.create(uuid_field=u'550e8400-e29b-41d4-a716-446655440010') - self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440010') - self.assertEquals(j.pk, u'550e8400-e29b-41d4-a716-446655440010') + j = TestManyToManyModel.objects.create(uuid_field=six.u('550e8400-e29b-41d4-a716-446655440010')) + self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440010')) + self.assertEqual(j.pk, six.u('550e8400-e29b-41d4-a716-446655440010')) diff --git a/awx/lib/site-packages/django_extensions/utils/dia2django.py b/awx/lib/site-packages/django_extensions/utils/dia2django.py index 792529b10a..91a48631df 100644 --- a/awx/lib/site-packages/django_extensions/utils/dia2django.py +++ b/awx/lib/site-packages/django_extensions/utils/dia2django.py @@ -17,6 +17,7 @@ import sys import gzip from xml.dom.minidom import * # NOQA import re +import six #Type dictionary translation types SQL -> Django tsd = { @@ -75,7 +76,7 @@ def dia2django(archivo): datos = ppal.getElementsByTagName("dia:diagram")[0].getElementsByTagName("dia:layer")[0].getElementsByTagName("dia:object") clases = {} herit = [] - imports = u"" + imports = six.u("") for i in datos: #Look for the classes if i.getAttribute("type") == "UML - Class": @@ -165,7 +166,7 @@ def dia2django(archivo): a = i.getElementsByTagName("dia:string") for j in a: if len(j.childNodes[0].data[1:-1]): - imports += u"from %s.models import *" % j.childNodes[0].data[1:-1] + imports += six.u("from %s.models import *" % j.childNodes[0].data[1:-1]) addparentstofks(herit, clases) #Ordering the appearance of classes diff --git a/awx/lib/site-packages/django_extensions/utils/uuid.py b/awx/lib/site-packages/django_extensions/utils/uuid.py deleted file mode 100644 index 2684f22019..0000000000 --- a/awx/lib/site-packages/django_extensions/utils/uuid.py +++ /dev/null @@ -1,566 +0,0 @@ -# flake8:noqa -r"""UUID objects (universally unique identifiers) according to RFC 4122. - -This module provides immutable UUID objects (class UUID) and the functions -uuid1(), uuid3(), uuid4(), uuid5() for generating version 1, 3, 4, and 5 -UUIDs as specified in RFC 4122. - -If all you want is a unique ID, you should probably call uuid1() or uuid4(). -Note that uuid1() may compromise privacy since it creates a UUID containing -the computer's network address. uuid4() creates a random UUID. - -Typical usage: - - >>> import uuid - - # make a UUID based on the host ID and current time - >>> uuid.uuid1() - UUID('a8098c1a-f86e-11da-bd1a-00112444be1e') - - # make a UUID using an MD5 hash of a namespace UUID and a name - >>> uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org') - UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e') - - # make a random UUID - >>> uuid.uuid4() - UUID('16fd2706-8baf-433b-82eb-8c7fada847da') - - # make a UUID using a SHA-1 hash of a namespace UUID and a name - >>> uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org') - UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d') - - # make a UUID from a string of hex digits (braces and hyphens ignored) - >>> x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}') - - # convert a UUID to a string of hex digits in standard form - >>> str(x) - '00010203-0405-0607-0809-0a0b0c0d0e0f' - - # get the raw 16 bytes of the UUID - >>> x.bytes - '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f' - - # make a UUID from a 16-byte string - >>> uuid.UUID(bytes=x.bytes) - UUID('00010203-0405-0607-0809-0a0b0c0d0e0f') -""" - -__author__ = 'Ka-Ping Yee <ping@zesty.ca>' - -RESERVED_NCS, RFC_4122, RESERVED_MICROSOFT, RESERVED_FUTURE = [ - 'reserved for NCS compatibility', 'specified in RFC 4122', - 'reserved for Microsoft compatibility', 'reserved for future definition' -] - - -class UUID(object): - """Instances of the UUID class represent UUIDs as specified in RFC 4122. - UUID objects are immutable, hashable, and usable as dictionary keys. - Converting a UUID to a string with str() yields something in the form - '12345678-1234-1234-1234-123456789abc'. The UUID constructor accepts - five possible forms: a similar string of hexadecimal digits, or a tuple - of six integer fields (with 32-bit, 16-bit, 16-bit, 8-bit, 8-bit, and - 48-bit values respectively) as an argument named 'fields', or a string - of 16 bytes (with all the integer fields in big-endian order) as an - argument named 'bytes', or a string of 16 bytes (with the first three - fields in little-endian order) as an argument named 'bytes_le', or a - single 128-bit integer as an argument named 'int'. - - UUIDs have these read-only attributes: - - bytes the UUID as a 16-byte string (containing the six - integer fields in big-endian byte order) - - bytes_le the UUID as a 16-byte string (with time_low, time_mid, - and time_hi_version in little-endian byte order) - - fields a tuple of the six integer fields of the UUID, - which are also available as six individual attributes - and two derived attributes: - - time_low the first 32 bits of the UUID - time_mid the next 16 bits of the UUID - time_hi_version the next 16 bits of the UUID - clock_seq_hi_variant the next 8 bits of the UUID - clock_seq_low the next 8 bits of the UUID - node the last 48 bits of the UUID - - time the 60-bit timestamp - clock_seq the 14-bit sequence number - - hex the UUID as a 32-character hexadecimal string - - int the UUID as a 128-bit integer - - urn the UUID as a URN as specified in RFC 4122 - - variant the UUID variant (one of the constants RESERVED_NCS, - RFC_4122, RESERVED_MICROSOFT, or RESERVED_FUTURE) - - version the UUID version number (1 through 5, meaningful only - when the variant is RFC_4122) - """ - - def __init__(self, hex=None, bytes=None, bytes_le=None, fields=None, int=None, version=None): - r"""Create a UUID from either a string of 32 hexadecimal digits, - a string of 16 bytes as the 'bytes' argument, a string of 16 bytes - in little-endian order as the 'bytes_le' argument, a tuple of six - integers (32-bit time_low, 16-bit time_mid, 16-bit time_hi_version, - 8-bit clock_seq_hi_variant, 8-bit clock_seq_low, 48-bit node) as - the 'fields' argument, or a single 128-bit integer as the 'int' - argument. When a string of hex digits is given, curly braces, - hyphens, and a URN prefix are all optional. For example, these - expressions all yield the same UUID: - - UUID('{12345678-1234-5678-1234-567812345678}') - UUID('12345678123456781234567812345678') - UUID('urn:uuid:12345678-1234-5678-1234-567812345678') - UUID(bytes='\x12\x34\x56\x78'*4) - UUID(bytes_le='\x78\x56\x34\x12\x34\x12\x78\x56' + - '\x12\x34\x56\x78\x12\x34\x56\x78') - UUID(fields=(0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678)) - UUID(int=0x12345678123456781234567812345678) - - Exactly one of 'hex', 'bytes', 'bytes_le', 'fields', or 'int' must - be given. The 'version' argument is optional; if given, the resulting - UUID will have its variant and version set according to RFC 4122, - overriding the given 'hex', 'bytes', 'bytes_le', 'fields', or 'int'. - """ - - if [hex, bytes, bytes_le, fields, int].count(None) != 4: - raise TypeError('need one of hex, bytes, bytes_le, fields, or int') - if hex is not None: - hex = hex.replace('urn:', '').replace('uuid:', '') - hex = hex.strip('{}').replace('-', '') - if len(hex) != 32: - raise ValueError('badly formed hexadecimal UUID string') - int = long(hex, 16) - if bytes_le is not None: - if len(bytes_le) != 16: - raise ValueError('bytes_le is not a 16-char string') - bytes = (bytes_le[3] + bytes_le[2] + bytes_le[1] + bytes_le[0] + - bytes_le[5] + bytes_le[4] + bytes_le[7] + bytes_le[6] + - bytes_le[8:]) - if bytes is not None: - if len(bytes) != 16: - raise ValueError('bytes is not a 16-char string') - int = long(('%02x' * 16) % tuple(map(ord, bytes)), 16) - if fields is not None: - if len(fields) != 6: - raise ValueError('fields is not a 6-tuple') - (time_low, time_mid, time_hi_version, - clock_seq_hi_variant, clock_seq_low, node) = fields - if not 0 <= time_low < 1 << 32L: - raise ValueError('field 1 out of range (need a 32-bit value)') - if not 0 <= time_mid < 1 << 16L: - raise ValueError('field 2 out of range (need a 16-bit value)') - if not 0 <= time_hi_version < 1 << 16L: - raise ValueError('field 3 out of range (need a 16-bit value)') - if not 0 <= clock_seq_hi_variant < 1 << 8L: - raise ValueError('field 4 out of range (need an 8-bit value)') - if not 0 <= clock_seq_low < 1 << 8L: - raise ValueError('field 5 out of range (need an 8-bit value)') - if not 0 <= node < 1 << 48L: - raise ValueError('field 6 out of range (need a 48-bit value)') - clock_seq = (clock_seq_hi_variant << 8L) | clock_seq_low - int = ((time_low << 96L) | (time_mid << 80L) | - (time_hi_version << 64L) | (clock_seq << 48L) | node) - if int is not None: - if not 0 <= int < 1 << 128L: - raise ValueError('int is out of range (need a 128-bit value)') - if version is not None: - if not 1 <= version <= 5: - raise ValueError('illegal version number') - # Set the variant to RFC 4122. - int &= ~(0xc000 << 48L) - int |= 0x8000 << 48L - # Set the version number. - int &= ~(0xf000 << 64L) - int |= version << 76L - self.__dict__['int'] = int - - def __cmp__(self, other): - if isinstance(other, UUID): - return cmp(self.int, other.int) - return NotImplemented - - def __hash__(self): - return hash(self.int) - - def __int__(self): - return self.int - - def __repr__(self): - return 'UUID(%r)' % str(self) - - def __setattr__(self, name, value): - raise TypeError('UUID objects are immutable') - - def __str__(self): - hex = '%032x' % self.int - return '%s-%s-%s-%s-%s' % ( - hex[:8], hex[8:12], hex[12:16], hex[16:20], hex[20:]) - - def get_bytes(self): - bytes = '' - for shift in range(0, 128, 8): - bytes = chr((self.int >> shift) & 0xff) + bytes - return bytes - - bytes = property(get_bytes) - - def get_bytes_le(self): - bytes = self.bytes - return (bytes[3] + bytes[2] + bytes[1] + bytes[0] + - bytes[5] + bytes[4] + bytes[7] + bytes[6] + bytes[8:]) - - bytes_le = property(get_bytes_le) - - def get_fields(self): - return (self.time_low, self.time_mid, self.time_hi_version, - self.clock_seq_hi_variant, self.clock_seq_low, self.node) - - fields = property(get_fields) - - def get_time_low(self): - return self.int >> 96L - - time_low = property(get_time_low) - - def get_time_mid(self): - return (self.int >> 80L) & 0xffff - - time_mid = property(get_time_mid) - - def get_time_hi_version(self): - return (self.int >> 64L) & 0xffff - - time_hi_version = property(get_time_hi_version) - - def get_clock_seq_hi_variant(self): - return (self.int >> 56L) & 0xff - - clock_seq_hi_variant = property(get_clock_seq_hi_variant) - - def get_clock_seq_low(self): - return (self.int >> 48L) & 0xff - - clock_seq_low = property(get_clock_seq_low) - - def get_time(self): - return (((self.time_hi_version & 0x0fffL) << 48L) | - (self.time_mid << 32L) | self.time_low) - - time = property(get_time) - - def get_clock_seq(self): - return (((self.clock_seq_hi_variant & 0x3fL) << 8L) | - self.clock_seq_low) - - clock_seq = property(get_clock_seq) - - def get_node(self): - return self.int & 0xffffffffffff - - node = property(get_node) - - def get_hex(self): - return '%032x' % self.int - - hex = property(get_hex) - - def get_urn(self): - return 'urn:uuid:' + str(self) - - urn = property(get_urn) - - def get_variant(self): - if not self.int & (0x8000 << 48L): - return RESERVED_NCS - elif not self.int & (0x4000 << 48L): - return RFC_4122 - elif not self.int & (0x2000 << 48L): - return RESERVED_MICROSOFT - else: - return RESERVED_FUTURE - - variant = property(get_variant) - - def get_version(self): - # The version bits are only meaningful for RFC 4122 UUIDs. - if self.variant == RFC_4122: - return int((self.int >> 76L) & 0xf) - - version = property(get_version) - - -def _find_mac(command, args, hw_identifiers, get_index): - import os - for dir in ['', '/sbin/', '/usr/sbin']: - executable = os.path.join(dir, command) - if not os.path.exists(executable): - continue - - try: - # LC_ALL to get English output, 2>/dev/null to - # prevent output on stderr - cmd = 'LC_ALL=C %s %s 2>/dev/null' % (executable, args) - pipe = os.popen(cmd) - except IOError: - continue - - for line in pipe: - words = line.lower().split() - for i in range(len(words)): - if words[i] in hw_identifiers: - return int(words[get_index(i)].replace(':', ''), 16) - return None - - -def _ifconfig_getnode(): - """Get the hardware address on Unix by running ifconfig.""" - - # This works on Linux ('' or '-a'), Tru64 ('-av'), but not all Unixes. - for args in ('', '-a', '-av'): - mac = _find_mac('ifconfig', args, ['hwaddr', 'ether'], lambda i: i + 1) - if mac: - return mac - - import socket - ip_addr = socket.gethostbyname(socket.gethostname()) - - # Try getting the MAC addr from arp based on our IP address (Solaris). - mac = _find_mac('arp', '-an', [ip_addr], lambda i: -1) - if mac: - return mac - - # This might work on HP-UX. - mac = _find_mac('lanscan', '-ai', ['lan0'], lambda i: 0) - if mac: - return mac - - return None - - -def _ipconfig_getnode(): - """Get the hardware address on Windows by running ipconfig.exe.""" - import os - import re - dirs = ['', r'c:\windows\system32', r'c:\winnt\system32'] - try: - import ctypes - buffer = ctypes.create_string_buffer(300) - ctypes.windll.kernel32.GetSystemDirectoryA(buffer, 300) - dirs.insert(0, buffer.value.decode('mbcs')) - except: - pass - for dir in dirs: - try: - pipe = os.popen(os.path.join(dir, 'ipconfig') + ' /all') - except IOError: - continue - for line in pipe: - value = line.split(':')[-1].strip().lower() - if re.match('([0-9a-f][0-9a-f]-){5}[0-9a-f][0-9a-f]', value): - return int(value.replace('-', ''), 16) - - -def _netbios_getnode(): - """Get the hardware address on Windows using NetBIOS calls. - See http://support.microsoft.com/kb/118623 for details.""" - import win32wnet - import netbios - ncb = netbios.NCB() - ncb.Command = netbios.NCBENUM - ncb.Buffer = adapters = netbios.LANA_ENUM() - adapters._pack() - if win32wnet.Netbios(ncb) != 0: - return - adapters._unpack() - for i in range(adapters.length): - ncb.Reset() - ncb.Command = netbios.NCBRESET - ncb.Lana_num = ord(adapters.lana[i]) - if win32wnet.Netbios(ncb) != 0: - continue - ncb.Reset() - ncb.Command = netbios.NCBASTAT - ncb.Lana_num = ord(adapters.lana[i]) - ncb.Callname = '*'.ljust(16) - ncb.Buffer = status = netbios.ADAPTER_STATUS() - if win32wnet.Netbios(ncb) != 0: - continue - status._unpack() - bytes = map(ord, status.adapter_address) - return ((bytes[0] << 40L) + (bytes[1] << 32L) + (bytes[2] << 24L) + - (bytes[3] << 16L) + (bytes[4] << 8L) + bytes[5]) - -# Thanks to Thomas Heller for ctypes and for his help with its use here. - -# If ctypes is available, use it to find system routines for UUID generation. -_uuid_generate_random = _uuid_generate_time = _UuidCreate = None -try: - import ctypes - import ctypes.util - _buffer = ctypes.create_string_buffer(16) - - # The uuid_generate_* routines are provided by libuuid on at least - # Linux and FreeBSD, and provided by libc on Mac OS X. - for libname in ['uuid', 'c']: - try: - lib = ctypes.CDLL(ctypes.util.find_library(libname)) - except: - continue - if hasattr(lib, 'uuid_generate_random'): - _uuid_generate_random = lib.uuid_generate_random - if hasattr(lib, 'uuid_generate_time'): - _uuid_generate_time = lib.uuid_generate_time - - # On Windows prior to 2000, UuidCreate gives a UUID containing the - # hardware address. On Windows 2000 and later, UuidCreate makes a - # random UUID and UuidCreateSequential gives a UUID containing the - # hardware address. These routines are provided by the RPC runtime. - # NOTE: at least on Tim's WinXP Pro SP2 desktop box, while the last - # 6 bytes returned by UuidCreateSequential are fixed, they don't appear - # to bear any relationship to the MAC address of any network device - # on the box. - try: - lib = ctypes.windll.rpcrt4 - except: - lib = None - _UuidCreate = getattr(lib, 'UuidCreateSequential', - getattr(lib, 'UuidCreate', None)) -except: - pass - - -def _unixdll_getnode(): - """Get the hardware address on Unix using ctypes.""" - _uuid_generate_time(_buffer) - return UUID(bytes=_buffer.raw).node - - -def _windll_getnode(): - """Get the hardware address on Windows using ctypes.""" - if _UuidCreate(_buffer) == 0: - return UUID(bytes=_buffer.raw).node - - -def _random_getnode(): - """Get a random node ID, with eighth bit set as suggested by RFC 4122.""" - import random - return random.randrange(0, 1 << 48L) | 0x010000000000L - -_node = None - - -def getnode(): - """Get the hardware address as a 48-bit positive integer. - - The first time this runs, it may launch a separate program, which could - be quite slow. If all attempts to obtain the hardware address fail, we - choose a random 48-bit number with its eighth bit set to 1 as recommended - in RFC 4122. - """ - - global _node - if _node is not None: - return _node - - import sys - if sys.platform == 'win32': - getters = [_windll_getnode, _netbios_getnode, _ipconfig_getnode] - else: - getters = [_unixdll_getnode, _ifconfig_getnode] - - for getter in getters + [_random_getnode]: - try: - _node = getter() - except: - continue - if _node is not None: - return _node - -_last_timestamp = None - - -def uuid1(node=None, clock_seq=None): - """Generate a UUID from a host ID, sequence number, and the current time. - If 'node' is not given, getnode() is used to obtain the hardware - address. If 'clock_seq' is given, it is used as the sequence number; - otherwise a random 14-bit sequence number is chosen.""" - - # When the system provides a version-1 UUID generator, use it (but don't - # use UuidCreate here because its UUIDs don't conform to RFC 4122). - if _uuid_generate_time and node is clock_seq is None: - _uuid_generate_time(_buffer) - return UUID(bytes=_buffer.raw) - - global _last_timestamp - import time - nanoseconds = int(time.time() * 1e9) - # 0x01b21dd213814000 is the number of 100-ns intervals between the - # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. - timestamp = int(nanoseconds / 100) + 0x01b21dd213814000L - if timestamp <= _last_timestamp: - timestamp = _last_timestamp + 1 - _last_timestamp = timestamp - if clock_seq is None: - import random - clock_seq = random.randrange(1 << 14L) # instead of stable storage - time_low = timestamp & 0xffffffffL - time_mid = (timestamp >> 32L) & 0xffffL - time_hi_version = (timestamp >> 48L) & 0x0fffL - clock_seq_low = clock_seq & 0xffL - clock_seq_hi_variant = (clock_seq >> 8L) & 0x3fL - if node is None: - node = getnode() - return UUID(fields=(time_low, time_mid, time_hi_version, - clock_seq_hi_variant, clock_seq_low, node), version=1) - - -def uuid3(namespace, name): - """Generate a UUID from the MD5 hash of a namespace UUID and a name.""" - try: - import hashlib - md5 = hashlib.md5 - except ImportError: - from md5 import md5 # NOQA - hash = md5(namespace.bytes + name).digest() - return UUID(bytes=hash[:16], version=3) - - -def uuid4(): - """Generate a random UUID.""" - - # When the system provides a version-4 UUID generator, use it. - if _uuid_generate_random: - _uuid_generate_random(_buffer) - return UUID(bytes=_buffer.raw) - - # Otherwise, get randomness from urandom or the 'random' module. - try: - import os - return UUID(bytes=os.urandom(16), version=4) - except: - import random - bytes = [chr(random.randrange(256)) for i in range(16)] - return UUID(bytes=bytes, version=4) - - -def uuid5(namespace, name): - """Generate a UUID from the SHA-1 hash of a namespace UUID and a name.""" - try: - import hashlib - sha = hashlib.sha1 - except ImportError: - from sha import sha # NOQA - hash = sha(namespace.bytes + name).digest() - return UUID(bytes=hash[:16], version=5) - -# The following standard UUIDs are for use with uuid3() or uuid5(). - -NAMESPACE_DNS = UUID('6ba7b810-9dad-11d1-80b4-00c04fd430c8') -NAMESPACE_URL = UUID('6ba7b811-9dad-11d1-80b4-00c04fd430c8') -NAMESPACE_OID = UUID('6ba7b812-9dad-11d1-80b4-00c04fd430c8') -NAMESPACE_X500 = UUID('6ba7b814-9dad-11d1-80b4-00c04fd430c8') diff --git a/awx/lib/site-packages/djcelery/__init__.py b/awx/lib/site-packages/djcelery/__init__.py index c281cfffdf..fea512138a 100644 --- a/awx/lib/site-packages/djcelery/__init__.py +++ b/awx/lib/site-packages/djcelery/__init__.py @@ -5,7 +5,7 @@ from __future__ import absolute_import import os -VERSION = (3, 0, 17) +VERSION = (3, 0, 21) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __author__ = 'Ask Solem' __contact__ = 'ask@celeryproject.org' diff --git a/awx/lib/site-packages/djcelery/admin.py b/awx/lib/site-packages/djcelery/admin.py index 85d9d4aefe..774b92cef2 100644 --- a/awx/lib/site-packages/djcelery/admin.py +++ b/awx/lib/site-packages/djcelery/admin.py @@ -8,7 +8,6 @@ from django.contrib.admin import helpers from django.contrib.admin.views import main as main_views from django.shortcuts import render_to_response from django.template import RequestContext -from django.utils.encoding import force_unicode from django.utils.html import escape from django.utils.translation import ugettext_lazy as _ @@ -22,6 +21,11 @@ from .models import (TaskState, WorkerState, PeriodicTask, IntervalSchedule, CrontabSchedule) from .humanize import naturaldate +try: + from django.utils.encoding import force_text +except ImportError: + from django.utils.encoding import force_unicode as force_text + TASK_STATE_COLORS = {states.SUCCESS: 'green', states.FAILURE: 'red', @@ -175,7 +179,7 @@ class TaskMonitor(ModelMonitor): context = { 'title': _('Rate limit selection'), 'queryset': queryset, - 'object_name': force_unicode(opts.verbose_name), + 'object_name': force_text(opts.verbose_name), 'action_checkbox_name': helpers.ACTION_CHECKBOX_NAME, 'opts': opts, 'app_label': app_label, diff --git a/awx/lib/site-packages/djcelery/picklefield.py b/awx/lib/site-packages/djcelery/picklefield.py index 1f2845df1f..73933e5124 100644 --- a/awx/lib/site-packages/djcelery/picklefield.py +++ b/awx/lib/site-packages/djcelery/picklefield.py @@ -20,7 +20,11 @@ from zlib import compress, decompress from celery.utils.serialization import pickle from django.db import models -from django.utils.encoding import force_unicode + +try: + from django.utils.encoding import force_text +except ImportError: + from django.utils.encoding import force_unicode as force_text DEFAULT_PROTOCOL = 2 @@ -77,7 +81,7 @@ class PickledObjectField(models.Field): def get_db_prep_value(self, value, **kwargs): if value is not None and not isinstance(value, PickledObject): - return force_unicode(encode(value, self.compress, self.protocol)) + return force_text(encode(value, self.compress, self.protocol)) return value def value_to_string(self, obj): diff --git a/awx/lib/site-packages/djcelery/schedulers.py b/awx/lib/site-packages/djcelery/schedulers.py index 135fb5a7b5..3e58508b30 100644 --- a/awx/lib/site-packages/djcelery/schedulers.py +++ b/awx/lib/site-packages/djcelery/schedulers.py @@ -68,12 +68,12 @@ class ModelEntry(ScheduleEntry): def _default_now(self): return self.app.now() - def next(self): + def __next__(self): self.model.last_run_at = self.app.now() self.model.total_run_count += 1 self.model.no_changes = True return self.__class__(self.model) - __next__ = next # for 2to3 + next = __next__ # for 2to3 def save(self): # Object may not be synchronized, so only diff --git a/awx/lib/site-packages/djcelery/snapshot.py b/awx/lib/site-packages/djcelery/snapshot.py index c0a112f4ab..cf2ea4dd55 100644 --- a/awx/lib/site-packages/djcelery/snapshot.py +++ b/awx/lib/site-packages/djcelery/snapshot.py @@ -10,7 +10,7 @@ from django.conf import settings from celery import states from celery.events.state import Task from celery.events.snapshot import Polaroid -from celery.utils.timeutils import maybe_iso8601, timezone +from celery.utils.timeutils import maybe_iso8601 from .models import WorkerState, TaskState from .utils import maybe_make_aware @@ -31,7 +31,7 @@ NOT_SAVED_ATTRIBUTES = frozenset(['name', 'args', 'kwargs', 'eta']) def aware_tstamp(secs): """Event timestamps uses the local timezone.""" - return timezone.to_local_fallback(datetime.fromtimestamp(secs)) + return maybe_make_aware(datetime.fromtimestamp(secs)) class Camera(Polaroid): diff --git a/awx/lib/site-packages/djcelery/utils.py b/awx/lib/site-packages/djcelery/utils.py index 02702ef390..fdfc0318ad 100644 --- a/awx/lib/site-packages/djcelery/utils.py +++ b/awx/lib/site-packages/djcelery/utils.py @@ -54,7 +54,8 @@ try: def make_aware(value): if getattr(settings, 'USE_TZ', False): # naive datetimes are assumed to be in UTC. - value = timezone.make_aware(value, timezone.utc) + if timezone.is_naive(value): + value = timezone.make_aware(value, timezone.utc) # then convert to the Django configured timezone. default_tz = timezone.get_default_timezone() value = timezone.localtime(value, default_tz) diff --git a/awx/lib/site-packages/kombu/__init__.py b/awx/lib/site-packages/kombu/__init__.py index 962eae85d0..56d511de34 100644 --- a/awx/lib/site-packages/kombu/__init__.py +++ b/awx/lib/site-packages/kombu/__init__.py @@ -1,7 +1,7 @@ """Messaging Framework for Python""" from __future__ import absolute_import -VERSION = (2, 5, 10) +VERSION = (2, 5, 14) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __author__ = 'Ask Solem' __contact__ = 'ask@celeryproject.org' diff --git a/awx/lib/site-packages/kombu/abstract.py b/awx/lib/site-packages/kombu/abstract.py index f4218f02b4..a4f0448e0c 100644 --- a/awx/lib/site-packages/kombu/abstract.py +++ b/awx/lib/site-packages/kombu/abstract.py @@ -1,6 +1,6 @@ """ -kombu.compression -================= +kombu.abstract +============== Object utilities. diff --git a/awx/lib/site-packages/kombu/connection.py b/awx/lib/site-packages/kombu/connection.py index 9a6146dc9b..a7a6b15e78 100644 --- a/awx/lib/site-packages/kombu/connection.py +++ b/awx/lib/site-packages/kombu/connection.py @@ -151,8 +151,9 @@ class Connection(object): password=None, virtual_host=None, port=None, insist=False, ssl=False, transport=None, connect_timeout=5, transport_options=None, login_method=None, uri_prefix=None, - heartbeat=0, failover_strategy='round-robin', **kwargs): - alt = [] + heartbeat=0, failover_strategy='round-robin', + alternates=None, **kwargs): + alt = [] if alternates is None else alternates # have to spell the args out, just to get nice docstrings :( params = self._initial_params = { 'hostname': hostname, 'userid': userid, @@ -328,6 +329,29 @@ class Connection(object): self._debug('closed') self._closed = True + def collect(self, socket_timeout=None): + # amqp requires communication to close, we don't need that just + # to clear out references, Transport._collect can also be implemented + # by other transports that want fast after fork + try: + gc_transport = self._transport._collect + except AttributeError: + _timeo = socket.getdefaulttimeout() + socket.setdefaulttimeout(socket_timeout) + try: + self._close() + except socket.timeout: + pass + finally: + socket.setdefaulttimeout(_timeo) + else: + gc_transport(self._connection) + if self._transport: + self._transport.client = None + self._transport = None + self.declared_entities.clear() + self._connection = None + def release(self): """Close the connection (if open).""" self._close() @@ -522,12 +546,9 @@ class Connection(object): transport_cls = RESOLVE_ALIASES.get(transport_cls, transport_cls) D = self.transport.default_connection_params - if self.alt: - hostname = ";".join(self.alt) - else: - hostname = self.hostname or D.get('hostname') - if self.uri_prefix: - hostname = '%s+%s' % (self.uri_prefix, hostname) + hostname = self.hostname or D.get('hostname') + if self.uri_prefix: + hostname = '%s+%s' % (self.uri_prefix, hostname) info = (('hostname', hostname), ('userid', self.userid or D.get('userid')), @@ -542,6 +563,10 @@ class Connection(object): ('login_method', self.login_method or D.get('login_method')), ('uri_prefix', self.uri_prefix), ('heartbeat', self.heartbeat)) + + if self.alt: + info += (('alternates', self.alt),) + return info def info(self): @@ -910,6 +935,9 @@ class Resource(object): else: self.close_resource(resource) + def collect_resource(self, resource): + pass + def force_close_all(self): """Closes and removes all resources in the pool (also those in use). @@ -919,32 +947,27 @@ class Resource(object): """ dirty = self._dirty resource = self._resource - while 1: + while 1: # - acquired try: dres = dirty.pop() except KeyError: break try: - self.close_resource(dres) + self.collect_resource(dres) except AttributeError: # Issue #78 pass - - mutex = getattr(resource, 'mutex', None) - if mutex: - mutex.acquire() - try: - while 1: - try: - res = resource.queue.pop() - except IndexError: - break - try: - self.close_resource(res) - except AttributeError: - pass # Issue #78 - finally: - if mutex: # pragma: no cover - mutex.release() + while 1: # - available + # deque supports '.clear', but lists do not, so for that + # reason we use pop here, so that the underlying object can + # be any object supporting '.pop' and '.append'. + try: + res = resource.queue.pop() + except IndexError: + break + try: + self.collect_resource(res) + except AttributeError: + pass # Issue #78 if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover _orig_acquire = acquire @@ -993,6 +1016,9 @@ class ConnectionPool(Resource): def close_resource(self, resource): resource._close() + def collect_resource(self, resource, socket_timeout=0.1): + return resource.collect(socket_timeout) + @contextmanager def acquire_channel(self, block=False): with self.acquire(block=block) as connection: diff --git a/awx/lib/site-packages/kombu/entity.py b/awx/lib/site-packages/kombu/entity.py index 32b199c377..eaf242e69d 100644 --- a/awx/lib/site-packages/kombu/entity.py +++ b/awx/lib/site-packages/kombu/entity.py @@ -123,6 +123,7 @@ class Exchange(MaybeChannelBound): type = 'direct' durable = True auto_delete = False + passive = False delivery_mode = PERSISTENT_DELIVERY_MODE attrs = ( @@ -130,6 +131,7 @@ class Exchange(MaybeChannelBound): ('type', None), ('arguments', None), ('durable', bool), + ('passive', bool), ('auto_delete', bool), ('delivery_mode', lambda m: DELIVERY_MODES.get(m) or m), ) @@ -143,7 +145,7 @@ class Exchange(MaybeChannelBound): def __hash__(self): return hash('E|%s' % (self.name, )) - def declare(self, nowait=False, passive=False): + def declare(self, nowait=False, passive=None): """Declare the exchange. Creates the exchange on the broker. @@ -152,6 +154,7 @@ class Exchange(MaybeChannelBound): response will not be waited for. Default is :const:`False`. """ + passive = self.passive if passive is None else passive if self.name: return self.channel.exchange_declare( exchange=self.name, type=self.type, durable=self.durable, @@ -489,7 +492,7 @@ class Queue(MaybeChannelBound): self.exchange.declare(nowait) self.queue_declare(nowait, passive=False) - if self.exchange is not None: + if self.exchange and self.exchange.name: self.queue_bind(nowait) # - declare extra/multi-bindings. @@ -541,8 +544,8 @@ class Queue(MaybeChannelBound): Returns the message instance if a message was available, or :const:`None` otherwise. - :keyword no_ack: If set messages received does not have to - be acknowledged. + :keyword no_ack: If enabled the broker will automatically + ack messages. This method provides direct access to the messages in a queue using a synchronous dialogue, designed for @@ -575,8 +578,8 @@ class Queue(MaybeChannelBound): can use the same consumer tags. If this field is empty the server will generate a unique tag. - :keyword no_ack: If set messages received does not have to - be acknowledged. + :keyword no_ack: If enabled the broker will automatically ack + messages. :keyword nowait: Do not wait for a reply. diff --git a/awx/lib/site-packages/kombu/messaging.py b/awx/lib/site-packages/kombu/messaging.py index 3067fc7b34..59be5bd085 100644 --- a/awx/lib/site-packages/kombu/messaging.py +++ b/awx/lib/site-packages/kombu/messaging.py @@ -276,8 +276,12 @@ class Consumer(object): #: consume from. queues = None - #: Flag for message acknowledgment disabled/enabled. - #: Enabled by default. + #: Flag for automatic message acknowledgment. + #: If enabled the messages are automatically acknowledged by the + #: broker. This can increase performance but means that you + #: have no control of when the message is removed. + #: + #: Disabled by default. no_ack = None #: By default all entities will be declared at instantiation, if you @@ -399,6 +403,12 @@ class Consumer(object): pass def add_queue(self, queue): + """Add a queue to the list of queues to consume from. + + This will not start consuming from the queue, + for that you will have to call :meth:`consume` after. + + """ queue = queue(self.channel) if self.auto_declare: queue.declare() @@ -406,9 +416,26 @@ class Consumer(object): return queue def add_queue_from_dict(self, queue, **options): + """This method is deprecated. + + Instead please use:: + + consumer.add_queue(Queue.from_dict(d)) + + """ return self.add_queue(Queue.from_dict(queue, **options)) def consume(self, no_ack=None): + """Start consuming messages. + + Can be called multiple times, but note that while it + will consume from new queues added since the last call, + it will not cancel consuming from removed queues ( + use :meth:`cancel_by_queue`). + + :param no_ack: See :attr:`no_ack`. + + """ if self.queues: no_ack = self.no_ack if no_ack is None else no_ack @@ -441,10 +468,12 @@ class Consumer(object): self.channel.basic_cancel(tag) def consuming_from(self, queue): + """Returns :const:`True` if the consumer is currently + consuming from queue'.""" name = queue if isinstance(queue, Queue): name = queue.name - return any(q.name == name for q in self.queues) + return name in self._active_tags def purge(self): """Purge messages from all queues. diff --git a/awx/lib/site-packages/kombu/mixins.py b/awx/lib/site-packages/kombu/mixins.py index 82b9733e92..36f4817911 100644 --- a/awx/lib/site-packages/kombu/mixins.py +++ b/awx/lib/site-packages/kombu/mixins.py @@ -13,6 +13,7 @@ import socket from contextlib import contextmanager from functools import partial from itertools import count +from time import sleep from .common import ignore_errors from .messaging import Consumer @@ -158,13 +159,18 @@ class ConsumerMixin(object): def extra_context(self, connection, channel): yield - def run(self): + def run(self, _tokens=1): + restart_limit = self.restart_limit + errors = (self.connection.connection_errors + + self.connection.channel_errors) while not self.should_stop: try: - if self.restart_limit.can_consume(1): + if restart_limit.can_consume(_tokens): for _ in self.consume(limit=None): pass - except self.connection.connection_errors: + else: + sleep(restart_limit.expected_time(_tokens)) + except errors: warn('Connection to broker lost. ' 'Trying to re-establish the connection...') diff --git a/awx/lib/site-packages/kombu/pidbox.py b/awx/lib/site-packages/kombu/pidbox.py index cc29ce1a99..0b42c809bb 100644 --- a/awx/lib/site-packages/kombu/pidbox.py +++ b/awx/lib/site-packages/kombu/pidbox.py @@ -20,6 +20,7 @@ from time import time from . import Exchange, Queue, Consumer, Producer from .clocks import LamportClock from .common import maybe_declare, oid_from +from .exceptions import InconsistencyError from .utils import cached_property, kwdict, uuid REPLY_QUEUE_EXPIRES = 10 @@ -215,9 +216,15 @@ class Mailbox(object): delivery_mode='transient', durable=False) producer = Producer(chan, auto_declare=False) - producer.publish(reply, exchange=exchange, routing_key=routing_key, - declare=[exchange], headers={ - 'ticket': ticket, 'clock': self.clock.forward()}) + try: + producer.publish( + reply, exchange=exchange, routing_key=routing_key, + declare=[exchange], headers={ + 'ticket': ticket, 'clock': self.clock.forward(), + }, + ) + except InconsistencyError: + pass # queue probably deleted and no one is expecting a reply. def _publish(self, type, arguments, destination=None, reply_ticket=None, channel=None, timeout=None): diff --git a/awx/lib/site-packages/kombu/tests/__init__.py b/awx/lib/site-packages/kombu/tests/__init__.py index 6a13a75074..fb9f21a210 100644 --- a/awx/lib/site-packages/kombu/tests/__init__.py +++ b/awx/lib/site-packages/kombu/tests/__init__.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import anyjson +import atexit import os import sys @@ -14,6 +15,24 @@ except ImportError: anyjson.force_implementation('simplejson') +def teardown(): + # Workaround for multiprocessing bug where logging + # is attempted after global already collected at shutdown. + cancelled = set() + try: + import multiprocessing.util + cancelled.add(multiprocessing.util._exit_function) + except (AttributeError, ImportError): + pass + + try: + atexit._exithandlers[:] = [ + e for e in atexit._exithandlers if e[0] not in cancelled + ] + except AttributeError: # pragma: no cover + pass # Py3 missing _exithandlers + + def find_distribution_modules(name=__name__, file=__file__): current_dist_depth = len(name.split('.')) - 1 current_dist = os.path.join(os.path.dirname(file), diff --git a/awx/lib/site-packages/kombu/tests/test_entities.py b/awx/lib/site-packages/kombu/tests/test_entities.py index 8e89f84ea8..f2d54756b9 100644 --- a/awx/lib/site-packages/kombu/tests/test_entities.py +++ b/awx/lib/site-packages/kombu/tests/test_entities.py @@ -130,6 +130,10 @@ class test_Exchange(TestCase): exc = Exchange('foo', 'direct', delivery_mode='transient') self.assertEqual(exc.delivery_mode, Exchange.TRANSIENT_DELIVERY_MODE) + def test_set_passive_mode(self): + exc = Exchange('foo', 'direct', passive=True) + self.assertTrue(exc.passive) + def test_set_persistent_delivery_mode(self): exc = Exchange('foo', 'direct', delivery_mode='persistent') self.assertEqual(exc.delivery_mode, Exchange.PERSISTENT_DELIVERY_MODE) diff --git a/awx/lib/site-packages/kombu/tests/test_messaging.py b/awx/lib/site-packages/kombu/tests/test_messaging.py index 0fb8a65751..477e4e9970 100644 --- a/awx/lib/site-packages/kombu/tests/test_messaging.py +++ b/awx/lib/site-packages/kombu/tests/test_messaging.py @@ -258,9 +258,13 @@ class test_Consumer(TestCase): def test_consuming_from(self): consumer = self.connection.Consumer() - consumer.queues[:] = [Queue('a'), Queue('b')] + consumer.queues[:] = [Queue('a'), Queue('b'), Queue('d')] + consumer._active_tags = {'a': 1, 'b': 2} + self.assertFalse(consumer.consuming_from(Queue('c'))) self.assertFalse(consumer.consuming_from('c')) + self.assertFalse(consumer.consuming_from(Queue('d'))) + self.assertFalse(consumer.consuming_from('d')) self.assertTrue(consumer.consuming_from(Queue('a'))) self.assertTrue(consumer.consuming_from(Queue('b'))) self.assertTrue(consumer.consuming_from('b')) diff --git a/awx/lib/site-packages/kombu/tests/test_serialization.py b/awx/lib/site-packages/kombu/tests/test_serialization.py index de7ed34312..d87a37f348 100644 --- a/awx/lib/site-packages/kombu/tests/test_serialization.py +++ b/awx/lib/site-packages/kombu/tests/test_serialization.py @@ -5,6 +5,8 @@ from __future__ import with_statement import sys +from base64 import b64decode + from kombu.serialization import (registry, register, SerializerNotInstalled, raw_encode, register_yaml, register_msgpack, decode, bytes_t, pickle, pickle_protocol, @@ -53,16 +55,13 @@ unicode: "Th\\xE9 quick brown fox jumps over th\\xE9 lazy dog" msgpack_py_data = dict(py_data) -# msgpack only supports tuples -msgpack_py_data['list'] = tuple(msgpack_py_data['list']) # Unicode chars are lost in transmit :( msgpack_py_data['unicode'] = 'Th quick brown fox jumps over th lazy dog' -msgpack_data = """\ -\x85\xa3int\n\xa5float\xcb@\t!\xfbS\xc8\xd4\xf1\xa4list\ -\x94\xa6george\xa5jerry\xa6elaine\xa5cosmo\xa6string\xda\ -\x00+The quick brown fox jumps over the lazy dog\xa7unicode\ -\xda\x00)Th quick brown fox jumps over th lazy dog\ -""" +msgpack_data = b64decode("""\ +haNpbnQKpWZsb2F0y0AJIftTyNTxpGxpc3SUpmdlb3JnZaVqZXJyeaZlbGFpbmWlY29zbW+mc3Rya\ +W5n2gArVGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZ6d1bmljb2Rl2g\ +ApVGggcXVpY2sgYnJvd24gZm94IGp1bXBzIG92ZXIgdGggbGF6eSBkb2c=\ +""") def say(m): diff --git a/awx/lib/site-packages/kombu/tests/test_utils.py b/awx/lib/site-packages/kombu/tests/test_utils.py index 7090546c64..c718c9e59c 100644 --- a/awx/lib/site-packages/kombu/tests/test_utils.py +++ b/awx/lib/site-packages/kombu/tests/test_utils.py @@ -215,7 +215,7 @@ class test_retry_over_time(TestCase): self.myfun, self.Predicate, max_retries=1, errback=self.errback, interval_max=14, ) - self.assertEqual(self.index, 2) + self.assertEqual(self.index, 1) # no errback self.assertRaises( self.Predicate, utils.retry_over_time, @@ -230,7 +230,7 @@ class test_retry_over_time(TestCase): self.myfun, self.Predicate, max_retries=0, errback=self.errback, interval_max=14, ) - self.assertEqual(self.index, 1) + self.assertEqual(self.index, 0) class test_cached_property(TestCase): diff --git a/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py b/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py index 745c014254..6840e407f8 100644 --- a/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py +++ b/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py @@ -3,8 +3,10 @@ from __future__ import with_statement import sys +from functools import partial from mock import patch from nose import SkipTest +from itertools import count try: import amqp # noqa @@ -43,6 +45,7 @@ class test_Channel(TestCase): pass self.conn = Mock() + self.conn._get_free_channel_id.side_effect = partial(next, count(0)) self.conn.channels = {} self.channel = Channel(self.conn, 0) diff --git a/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py b/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py index b2942b671d..1bac68e6a6 100644 --- a/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py +++ b/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from mock import patch from nose import SkipTest -from kombu.utils.encoding import safe_str +from kombu.utils.encoding import bytes_t, safe_str, default_encoding from kombu.tests.utils import TestCase @@ -26,16 +26,16 @@ def clean_encoding(): class test_default_encoding(TestCase): - @patch('sys.getfilesystemencoding') - def test_default(self, getfilesystemencoding): - getfilesystemencoding.return_value = 'ascii' + @patch('sys.getdefaultencoding') + def test_default(self, getdefaultencoding): + getdefaultencoding.return_value = 'ascii' with clean_encoding() as encoding: enc = encoding.default_encoding() if sys.platform.startswith('java'): self.assertEqual(enc, 'utf-8') else: self.assertEqual(enc, 'ascii') - getfilesystemencoding.assert_called_with() + getdefaultencoding.assert_called_with() class test_encoding_utils(TestCase): @@ -60,16 +60,36 @@ class test_encoding_utils(TestCase): class test_safe_str(TestCase): - def test_when_str(self): + def setUp(self): + self._cencoding = patch('sys.getdefaultencoding') + self._encoding = self._cencoding.__enter__() + self._encoding.return_value = 'ascii' + + def tearDown(self): + self._cencoding.__exit__() + + def test_when_bytes(self): self.assertEqual(safe_str('foo'), 'foo') def test_when_unicode(self): - self.assertIsInstance(safe_str(u'foo'), str) + self.assertIsInstance(safe_str(u'foo'), bytes_t) + + def test_when_encoding_utf8(self): + with patch('sys.getdefaultencoding') as encoding: + encoding.return_value = 'utf-8' + self.assertEqual(default_encoding(), 'utf-8') + s = u'The quiæk fåx jømps øver the lazy dåg' + res = safe_str(s) + self.assertIsInstance(res, bytes_t) + self.assertGreater(len(res), len(s)) def test_when_containing_high_chars(self): - s = u'The quiæk fåx jømps øver the lazy dåg' - res = safe_str(s) - self.assertIsInstance(res, str) + with patch('sys.getdefaultencoding') as encoding: + encoding.return_value = 'ascii' + s = u'The quiæk fåx jømps øver the lazy dåg' + res = safe_str(s) + self.assertIsInstance(res, bytes_t) + self.assertEqual(len(s), len(res)) def test_when_not_string(self): o = object() diff --git a/awx/lib/site-packages/kombu/transport/librabbitmq.py b/awx/lib/site-packages/kombu/transport/librabbitmq.py index acfa33e701..0909709198 100644 --- a/awx/lib/site-packages/kombu/transport/librabbitmq.py +++ b/awx/lib/site-packages/kombu/transport/librabbitmq.py @@ -9,6 +9,7 @@ kombu.transport.librabbitmq """ from __future__ import absolute_import +import os import socket try: @@ -28,6 +29,10 @@ from . import base DEFAULT_PORT = 5672 +NO_SSL_ERROR = """\ +ssl not supported by librabbitmq, please use pyamqp:// or stunnel\ +""" + class Message(base.Message): @@ -98,22 +103,41 @@ class Transport(base.Transport): for name, default_value in self.default_connection_params.items(): if not getattr(conninfo, name, None): setattr(conninfo, name, default_value) - conn = self.Connection(host=conninfo.host, - userid=conninfo.userid, - password=conninfo.password, - virtual_host=conninfo.virtual_host, - login_method=conninfo.login_method, - insist=conninfo.insist, - ssl=conninfo.ssl, - connect_timeout=conninfo.connect_timeout) + if conninfo.ssl: + raise NotImplementedError(NO_SSL_ERROR) + opts = dict({ + 'host': conninfo.host, + 'userid': conninfo.userid, + 'password': conninfo.password, + 'virtual_host': conninfo.virtual_host, + 'login_method': conninfo.login_method, + 'insist': conninfo.insist, + 'ssl': conninfo.ssl, + 'connect_timeout': conninfo.connect_timeout, + }, **conninfo.transport_options or {}) + conn = self.Connection(**opts) conn.client = self.client self.client.drain_events = conn.drain_events return conn def close_connection(self, connection): """Close the AMQP broker connection.""" + self.client.drain_events = None connection.close() + def _collect(self, connection): + if connection is not None: + for channel in connection.channels.itervalues(): + channel.connection = None + try: + os.close(connection.fileno()) + except OSError: + pass + connection.channels.clear() + connection.callbacks.clear() + self.client.drain_events = None + self.client = None + def verify_connection(self, connection): return connection.connected diff --git a/awx/lib/site-packages/kombu/transport/mongodb.py b/awx/lib/site-packages/kombu/transport/mongodb.py index 1a09d1f42b..ff0ff3e811 100644 --- a/awx/lib/site-packages/kombu/transport/mongodb.py +++ b/awx/lib/site-packages/kombu/transport/mongodb.py @@ -101,57 +101,42 @@ class Channel(virtual.Channel): See mongodb uri documentation: http://www.mongodb.org/display/DOCS/Connections """ - conninfo = self.connection.client + client = self.connection.client + hostname = client.hostname or DEFAULT_HOST + authdb = dbname = client.virtual_host - dbname = None - hostname = None - - if not conninfo.hostname: - conninfo.hostname = DEFAULT_HOST - - for part in conninfo.hostname.split('/'): - if not hostname: - hostname = 'mongodb://' + part - continue - - dbname = part - if '?' in part: - # In case someone is passing options - # to the mongodb connection. Right now - # it is not permitted by kombu - dbname, options = part.split('?') - hostname += '/?' + options - - hostname = "%s/%s" % ( - hostname, dbname in [None, "/"] and "admin" or dbname, - ) - if not dbname or dbname == "/": + if dbname in ["/", None]: dbname = "kombu_default" + authdb = "admin" + if not client.userid: + hostname = hostname.replace('/' + client.virtual_host, '/') + else: + hostname = hostname.replace('/' + client.virtual_host, + '/' + authdb) + + mongo_uri = 'mongodb://' + hostname # At this point we expect the hostname to be something like # (considering replica set form too): # # mongodb://[username:password@]host1[:port1][,host2[:port2], # ...[,hostN[:portN]]][/[?options]] - mongoconn = Connection(host=hostname, ssl=conninfo.ssl) + mongoconn = Connection(host=mongo_uri, ssl=client.ssl) + database = getattr(mongoconn, dbname) + version = mongoconn.server_info()['version'] if tuple(map(int, version.split('.')[:2])) < (1, 3): raise NotImplementedError( 'Kombu requires MongoDB version 1.3+, but connected to %s' % ( version, )) - database = getattr(mongoconn, dbname) - - # This is done by the connection uri - # if conninfo.userid: - # database.authenticate(conninfo.userid, conninfo.password) self.db = database col = database.messages col.ensure_index([('queue', 1), ('_id', 1)], background=True) if 'messages.broadcast' not in database.collection_names(): - capsize = conninfo.transport_options.get( - 'capped_queue_size') or 100000 + capsize = (client.transport_options.get('capped_queue_size') + or 100000) database.create_collection('messages.broadcast', size=capsize, capped=True) diff --git a/awx/lib/site-packages/kombu/transport/pyamqp.py b/awx/lib/site-packages/kombu/transport/pyamqp.py index 8eb62dcccb..8b066e7b64 100644 --- a/awx/lib/site-packages/kombu/transport/pyamqp.py +++ b/awx/lib/site-packages/kombu/transport/pyamqp.py @@ -75,14 +75,17 @@ class Transport(base.Transport): channel_errors = (StdChannelError, ) + amqp.Connection.channel_errors nb_keep_draining = True - driver_name = "py-amqp" - driver_type = "amqp" + driver_name = 'py-amqp' + driver_type = 'amqp' supports_heartbeats = True supports_ev = True def __init__(self, client, **kwargs): self.client = client - self.default_port = kwargs.get("default_port") or self.default_port + self.default_port = kwargs.get('default_port') or self.default_port + + def driver_version(self): + return amqp.__version__ def create_channel(self, connection): return connection.channel() @@ -98,15 +101,18 @@ class Transport(base.Transport): setattr(conninfo, name, default_value) if conninfo.hostname == 'localhost': conninfo.hostname = '127.0.0.1' - conn = self.Connection(host=conninfo.host, - userid=conninfo.userid, - password=conninfo.password, - login_method=conninfo.login_method, - virtual_host=conninfo.virtual_host, - insist=conninfo.insist, - ssl=conninfo.ssl, - connect_timeout=conninfo.connect_timeout, - heartbeat=conninfo.heartbeat) + opts = dict({ + 'host': conninfo.host, + 'userid': conninfo.userid, + 'password': conninfo.password, + 'login_method': conninfo.login_method, + 'virtual_host': conninfo.virtual_host, + 'insist': conninfo.insist, + 'ssl': conninfo.ssl, + 'connect_timeout': conninfo.connect_timeout, + 'heartbeat': conninfo.heartbeat, + }, **conninfo.transport_options or {}) + conn = self.Connection(**opts) conn.client = self.client return conn diff --git a/awx/lib/site-packages/kombu/transport/redis.py b/awx/lib/site-packages/kombu/transport/redis.py index bf9e9f49ad..687b57c586 100644 --- a/awx/lib/site-packages/kombu/transport/redis.py +++ b/awx/lib/site-packages/kombu/transport/redis.py @@ -206,7 +206,7 @@ class MultiChannelPoller(object): for fd in self._chan_to_sock.itervalues(): try: self.poller.unregister(fd) - except KeyError: + except (KeyError, ValueError): pass self._channels.clear() self._fd_to_chan.clear() @@ -707,11 +707,12 @@ class Channel(virtual.Channel): return self._queue_cycle[0:active] def _rotate_cycle(self, used): - """ - Move most recently used queue to end of list - """ - index = self._queue_cycle.index(used) - self._queue_cycle.append(self._queue_cycle.pop(index)) + """Move most recently used queue to end of list.""" + cycle = self._queue_cycle + try: + cycle.append(cycle.pop(cycle.index(used))) + except ValueError: + pass def _get_response_error(self): from redis import exceptions diff --git a/awx/lib/site-packages/kombu/utils/__init__.py b/awx/lib/site-packages/kombu/utils/__init__.py index 532fb883b4..c52382641a 100644 --- a/awx/lib/site-packages/kombu/utils/__init__.py +++ b/awx/lib/site-packages/kombu/utils/__init__.py @@ -216,7 +216,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, try: return fun(*args, **kwargs) except catch, exc: - if max_retries is not None and retries > max_retries: + if max_retries is not None and retries >= max_retries: raise if callback: callback() diff --git a/awx/lib/site-packages/kombu/utils/encoding.py b/awx/lib/site-packages/kombu/utils/encoding.py index efeeb32957..a4a1698938 100644 --- a/awx/lib/site-packages/kombu/utils/encoding.py +++ b/awx/lib/site-packages/kombu/utils/encoding.py @@ -22,7 +22,7 @@ if sys.platform.startswith('java'): # pragma: no cover else: def default_encoding(): # noqa - return sys.getfilesystemencoding() + return sys.getdefaultencoding() if is_py3k: # pragma: no cover diff --git a/awx/lib/site-packages/kombu/utils/eventio.py b/awx/lib/site-packages/kombu/utils/eventio.py index e9acc923e8..80b1501f9e 100644 --- a/awx/lib/site-packages/kombu/utils/eventio.py +++ b/awx/lib/site-packages/kombu/utils/eventio.py @@ -10,7 +10,7 @@ from __future__ import absolute_import import errno import socket -from select import select as _selectf +from select import select as _selectf, error as _selecterr try: from select import epoll @@ -53,6 +53,11 @@ READ = POLL_READ = 0x001 WRITE = POLL_WRITE = 0x004 ERR = POLL_ERR = 0x008 | 0x010 +try: + SELECT_BAD_FD = set((errno.EBADF, errno.WSAENOTSOCK)) +except AttributeError: + SELECT_BAD_FD = set((errno.EBADF,)) + class Poller(object): @@ -79,11 +84,9 @@ class _epoll(Poller): def unregister(self, fd): try: self._epoll.unregister(fd) - except socket.error: - pass - except ValueError: + except (socket.error, ValueError, KeyError): pass - except IOError, exc: + except (IOError, OSError), exc: if get_errno(exc) != errno.ENOENT: raise @@ -191,13 +194,31 @@ class _select(Poller): if events & READ: self._rfd.add(fd) + def _remove_bad(self): + for fd in self._rfd | self._wfd | self._efd: + try: + _selectf([fd], [], [], 0) + except (_selecterr, socket.error), exc: + if get_errno(exc) in SELECT_BAD_FD: + self.unregister(fd) + def unregister(self, fd): self._rfd.discard(fd) self._wfd.discard(fd) self._efd.discard(fd) def _poll(self, timeout): - read, write, error = _selectf(self._rfd, self._wfd, self._efd, timeout) + try: + read, write, error = _selectf( + self._rfd, self._wfd, self._efd, timeout, + ) + except (_selecterr, socket.error), exc: + if get_errno(exc) == errno.EINTR: + return + elif get_errno(exc) in SELECT_BAD_FD: + return self._remove_bad() + raise + events = {} for fd in read: if not isinstance(fd, int): diff --git a/awx/lib/site-packages/rest_framework/__init__.py b/awx/lib/site-packages/rest_framework/__init__.py index 0a21018634..087808e0b2 100644 --- a/awx/lib/site-packages/rest_framework/__init__.py +++ b/awx/lib/site-packages/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.5' +__version__ = '2.3.7' VERSION = __version__ # synonym diff --git a/awx/lib/site-packages/rest_framework/authentication.py b/awx/lib/site-packages/rest_framework/authentication.py index 9caca78894..cf001a24dd 100644 --- a/awx/lib/site-packages/rest_framework/authentication.py +++ b/awx/lib/site-packages/rest_framework/authentication.py @@ -3,14 +3,13 @@ Provides various authentication policies. """ from __future__ import unicode_literals import base64 -from datetime import datetime from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider +from rest_framework.compat import oauth2_provider, provider_now from rest_framework.authtoken.models import Token @@ -27,6 +26,12 @@ def get_authorization_header(request): return auth +class CSRFCheck(CsrfViewMiddleware): + def _reject(self, request, reason): + # Return the failure reason instead of an HttpResponse + return reason + + class BaseAuthentication(object): """ All authentication classes should extend BaseAuthentication. @@ -104,27 +109,27 @@ class SessionAuthentication(BaseAuthentication): """ # Get the underlying HttpRequest object - http_request = request._request - user = getattr(http_request, 'user', None) + request = request._request + user = getattr(request, 'user', None) # Unauthenticated, CSRF validation not required if not user or not user.is_active: return None - # Enforce CSRF validation for session based authentication. - class CSRFCheck(CsrfViewMiddleware): - def _reject(self, request, reason): - # Return the failure reason instead of an HttpResponse - return reason + self.enforce_csrf(request) + + # CSRF passed with authenticated user + return (user, None) - reason = CSRFCheck().process_view(http_request, None, (), {}) + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication. + """ + reason = CSRFCheck().process_view(request, None, (), {}) if reason: # CSRF failed, bail with explicit error message raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) - # CSRF passed with authenticated user - return (user, None) - class TokenAuthentication(BaseAuthentication): """ @@ -230,8 +235,9 @@ class OAuthAuthentication(BaseAuthentication): try: consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) - except oauth_provider.store.InvalidConsumerError as err: - raise exceptions.AuthenticationFailed(err) + except oauth_provider.store.InvalidConsumerError: + msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key') + raise exceptions.AuthenticationFailed(msg) if consumer.status != oauth_provider.consts.ACCEPTED: msg = 'Invalid consumer key status: %s' % consumer.get_status_display() @@ -319,9 +325,9 @@ class OAuth2Authentication(BaseAuthentication): try: token = oauth2_provider.models.AccessToken.objects.select_related('user') - # TODO: Change to timezone aware datetime when oauth2_provider add - # support to it. - token = token.get(token=access_token, expires__gt=datetime.now()) + # provider_now switches to timezone aware datetime when + # the oauth2_provider version supports to it. + token = token.get(token=access_token, expires__gt=provider_now()) except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') diff --git a/awx/lib/site-packages/rest_framework/authtoken/admin.py b/awx/lib/site-packages/rest_framework/authtoken/admin.py new file mode 100644 index 0000000000..ec28eb1ca2 --- /dev/null +++ b/awx/lib/site-packages/rest_framework/authtoken/admin.py @@ -0,0 +1,11 @@ +from django.contrib import admin +from rest_framework.authtoken.models import Token + + +class TokenAdmin(admin.ModelAdmin): + list_display = ('key', 'user', 'created') + fields = ('user',) + ordering = ('-created',) + + +admin.site.register(Token, TokenAdmin) diff --git a/awx/lib/site-packages/rest_framework/authtoken/models.py b/awx/lib/site-packages/rest_framework/authtoken/models.py index 52c45ad11f..7601f5b791 100644 --- a/awx/lib/site-packages/rest_framework/authtoken/models.py +++ b/awx/lib/site-packages/rest_framework/authtoken/models.py @@ -1,7 +1,7 @@ import uuid import hmac from hashlib import sha1 -from rest_framework.compat import User +from rest_framework.compat import AUTH_USER_MODEL from django.conf import settings from django.db import models @@ -11,7 +11,7 @@ class Token(models.Model): The default authorization token model. """ key = models.CharField(max_length=40, primary_key=True) - user = models.OneToOneField(User, related_name='auth_token') + user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') created = models.DateTimeField(auto_now_add=True) class Meta: diff --git a/awx/lib/site-packages/rest_framework/compat.py b/awx/lib/site-packages/rest_framework/compat.py index 76dc00526c..6f7447add0 100644 --- a/awx/lib/site-packages/rest_framework/compat.py +++ b/awx/lib/site-packages/rest_framework/compat.py @@ -2,11 +2,13 @@ The `compat` module provides support for backwards compatibility with older versions of django/python, and compatibility wrappers around optional packages. """ + # flake8: noqa from __future__ import unicode_literals import django from django.core.exceptions import ImproperlyConfigured +from django.conf import settings # Try to import six from Django, fallback to included `six`. try: @@ -33,6 +35,12 @@ except ImportError: from django.utils.encoding import force_unicode as force_text +# HttpResponseBase only exists from 1.5 onwards +try: + from django.http.response import HttpResponseBase +except ImportError: + from django.http import HttpResponse as HttpResponseBase + # django-filter is optional try: import django_filters @@ -76,16 +84,9 @@ def get_concrete_model(model_cls): # Django 1.5 add support for custom auth user model if django.VERSION >= (1, 5): - from django.conf import settings - if hasattr(settings, 'AUTH_USER_MODEL'): - User = settings.AUTH_USER_MODEL - else: - from django.contrib.auth.models import User + AUTH_USER_MODEL = settings.AUTH_USER_MODEL else: - try: - from django.contrib.auth.models import User - except ImportError: - raise ImportError("User model is not to be found.") + AUTH_USER_MODEL = 'auth.User' if django.VERSION >= (1, 5): @@ -435,6 +436,42 @@ except ImportError: return force_text(url) +# RequestFactory only provide `generic` from 1.5 onwards + +from django.test.client import RequestFactory as DjangoRequestFactory +from django.test.client import FakePayload +try: + # In 1.5 the test client uses force_bytes + from django.utils.encoding import force_bytes_or_smart_bytes +except ImportError: + # In 1.3 and 1.4 the test client just uses smart_str + from django.utils.encoding import smart_str as force_bytes_or_smart_bytes + + +class RequestFactory(DjangoRequestFactory): + def generic(self, method, path, + data='', content_type='application/octet-stream', **extra): + parsed = urlparse.urlparse(path) + data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + r = { + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': force_text(parsed[4]), + 'REQUEST_METHOD': str(method), + } + if data: + r.update({ + 'CONTENT_LENGTH': len(data), + 'CONTENT_TYPE': str(content_type), + 'wsgi.input': FakePayload(data), + }) + elif django.VERSION <= (1, 4): + # For 1.3 we need an empty WSGI payload + r.update({ + 'wsgi.input': FakePayload('') + }) + r.update(extra) + return self.request(**r) + # Markdown is optional try: import markdown @@ -489,12 +526,22 @@ try: from provider.oauth2 import forms as oauth2_provider_forms from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants + from provider import __version__ as provider_version + if provider_version in ('0.2.3', '0.2.4'): + # 0.2.3 and 0.2.4 are supported version that do not support + # timezone aware datetimes + import datetime + provider_now = datetime.datetime.now + else: + # Any other supported version does use timezone aware datetimes + from django.utils.timezone import now as provider_now except ImportError: oauth2_provider = None oauth2_provider_models = None oauth2_provider_forms = None oauth2_provider_scope = None oauth2_constants = None + provider_now = None # Handle lazy strings from django.utils.functional import Promise diff --git a/awx/lib/site-packages/rest_framework/exceptions.py b/awx/lib/site-packages/rest_framework/exceptions.py index 0c96ecdd52..425a721499 100644 --- a/awx/lib/site-packages/rest_framework/exceptions.py +++ b/awx/lib/site-packages/rest_framework/exceptions.py @@ -86,10 +86,3 @@ class Throttled(APIException): self.detail = format % (self.wait, self.wait != 1 and 's' or '') else: self.detail = detail or self.default_detail - - -class ConfigurationError(Exception): - """ - Indicates an internal server error. - """ - pass diff --git a/awx/lib/site-packages/rest_framework/fields.py b/awx/lib/site-packages/rest_framework/fields.py index 535aa2ac8e..add9d224d3 100644 --- a/awx/lib/site-packages/rest_framework/fields.py +++ b/awx/lib/site-packages/rest_framework/fields.py @@ -7,25 +7,24 @@ from __future__ import unicode_literals import copy import datetime -from decimal import Decimal, DecimalException import inspect import re import warnings +from decimal import Decimal, DecimalException +from django import forms from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings from django.db.models.fields import BLANK_CHOICE_DASH -from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from django.utils.datastructures import SortedDict from rest_framework import ISO_8601 -from rest_framework.compat import (timezone, parse_date, parse_datetime, - parse_time) -from rest_framework.compat import BytesIO -from rest_framework.compat import six -from rest_framework.compat import smart_text, force_text, is_non_str_iterable +from rest_framework.compat import ( + timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text, + force_text, is_non_str_iterable +) from rest_framework.settings import api_settings @@ -101,6 +100,19 @@ def humanize_strptime(format_string): return format_string +def strip_multiple_choice_msg(help_text): + """ + Remove the 'Hold down "control" ...' message that is Django enforces in + select multiple fields on ModelForms. (Required for 1.5 and earlier) + + See https://code.djangoproject.com/ticket/9321 + """ + multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') + multiple_choice_msg = force_text(multiple_choice_msg) + + return help_text.replace(multiple_choice_msg, '') + + class Field(object): read_only = True creation_counter = 0 @@ -123,7 +135,7 @@ class Field(object): self.label = smart_text(label) if help_text is not None: - self.help_text = smart_text(help_text) + self.help_text = strip_multiple_choice_msg(smart_text(help_text)) def initialize(self, parent, field_name): """ @@ -256,6 +268,12 @@ class WritableField(Field): widget = widget() self.widget = widget + def __deepcopy__(self, memo): + result = copy.copy(self) + memo[id(self)] = result + result.validators = self.validators[:] + return result + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -331,9 +349,13 @@ class ModelField(WritableField): raise ValueError("ModelField requires 'model_field' kwarg") self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) + getattr(self.model_field, 'min_length', None)) self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) + getattr(self.model_field, 'max_length', None)) + self.min_value = kwargs.pop('min_value', + getattr(self.model_field, 'min_value', None)) + self.max_value = kwargs.pop('max_value', + getattr(self.model_field, 'max_value', None)) super(ModelField, self).__init__(*args, **kwargs) @@ -341,6 +363,10 @@ class ModelField(WritableField): self.validators.append(validators.MinLengthValidator(self.min_length)) if self.max_length is not None: self.validators.append(validators.MaxLengthValidator(self.max_length)) + if self.min_value is not None: + self.validators.append(validators.MinValueValidator(self.min_value)) + if self.max_value is not None: + self.validators.append(validators.MaxValueValidator(self.max_value)) def from_native(self, value): rel = getattr(self.model_field, "rel", None) @@ -428,13 +454,6 @@ class SlugField(CharField): def __init__(self, *args, **kwargs): super(SlugField, self).__init__(*args, **kwargs) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - #result.widget = copy.deepcopy(self.widget, memo) - result.validators = self.validators[:] - return result - class ChoiceField(WritableField): type_name = 'ChoiceField' @@ -493,7 +512,7 @@ class EmailField(CharField): form_field_class = forms.EmailField default_error_messages = { - 'invalid': _('Enter a valid e-mail address.'), + 'invalid': _('Enter a valid email address.'), } default_validators = [validators.validate_email] @@ -503,13 +522,6 @@ class EmailField(CharField): return None return ret.strip() - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - #result.widget = copy.deepcopy(self.widget, memo) - result.validators = self.validators[:] - return result - class RegexField(CharField): type_name = 'RegexField' @@ -534,12 +546,6 @@ class RegexField(CharField): regex = property(_get_regex, _set_regex) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - result.validators = self.validators[:] - return result - class DateField(WritableField): type_name = 'DateField' @@ -918,7 +924,7 @@ class ImageField(FileField): if f is None: return None - from compat import Image + from rest_framework.compat import Image assert Image is not None, 'PIL must be installed for ImageField support' # We need to get a file object for PIL. We might have a path or we might diff --git a/awx/lib/site-packages/rest_framework/filters.py b/awx/lib/site-packages/rest_framework/filters.py index c058bc715e..4079e1bd55 100644 --- a/awx/lib/site-packages/rest_framework/filters.py +++ b/awx/lib/site-packages/rest_framework/filters.py @@ -109,8 +109,7 @@ class OrderingFilter(BaseFilterBackend): def get_ordering(self, request): """ - Search terms are set by a ?search=... query parameter, - and may be comma and/or whitespace delimited. + Ordering is set by a comma delimited ?ordering=... query parameter. """ params = request.QUERY_PARAMS.get(self.ordering_param) if params: @@ -134,7 +133,7 @@ class OrderingFilter(BaseFilterBackend): ordering = self.remove_invalid_fields(queryset, ordering) if not ordering: - # Use 'ordering' attribtue by default + # Use 'ordering' attribute by default ordering = self.get_default_ordering(view) if ordering: diff --git a/awx/lib/site-packages/rest_framework/generics.py b/awx/lib/site-packages/rest_framework/generics.py index 9ccc789805..99e9782e21 100644 --- a/awx/lib/site-packages/rest_framework/generics.py +++ b/awx/lib/site-packages/rest_framework/generics.py @@ -212,7 +212,7 @@ class GenericAPIView(views.APIView): You may want to override this if you need to provide different serializations depending on the incoming request. - (Eg. admins get full serialization, others get basic serilization) + (Eg. admins get full serialization, others get basic serialization) """ serializer_class = self.serializer_class if serializer_class is not None: @@ -285,7 +285,7 @@ class GenericAPIView(views.APIView): ) filter_kwargs = {self.slug_field: slug} else: - raise exceptions.ConfigurationError( + raise ImproperlyConfigured( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'attribute on the view correctly.' % diff --git a/awx/lib/site-packages/rest_framework/parsers.py b/awx/lib/site-packages/rest_framework/parsers.py index 25be2e6abc..96bfac84a0 100644 --- a/awx/lib/site-packages/rest_framework/parsers.py +++ b/awx/lib/site-packages/rest_framework/parsers.py @@ -50,10 +50,7 @@ class JSONParser(BaseParser): def parse(self, stream, media_type=None, parser_context=None): """ - Returns a 2-tuple of `(data, files)`. - - `data` will be an object which is the parsed content of the response. - `files` will always be `None`. + Parses the incoming bytestream as JSON and returns the resulting data. """ parser_context = parser_context or {} encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) @@ -74,10 +71,7 @@ class YAMLParser(BaseParser): def parse(self, stream, media_type=None, parser_context=None): """ - Returns a 2-tuple of `(data, files)`. - - `data` will be an object which is the parsed content of the response. - `files` will always be `None`. + Parses the incoming bytestream as YAML and returns the resulting data. """ assert yaml, 'YAMLParser requires pyyaml to be installed' @@ -100,10 +94,8 @@ class FormParser(BaseParser): def parse(self, stream, media_type=None, parser_context=None): """ - Returns a 2-tuple of `(data, files)`. - - `data` will be a :class:`QueryDict` containing all the form parameters. - `files` will always be :const:`None`. + Parses the incoming bytestream as a URL encoded form, + and returns the resulting QueryDict. """ parser_context = parser_context or {} encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) @@ -120,7 +112,8 @@ class MultiPartParser(BaseParser): def parse(self, stream, media_type=None, parser_context=None): """ - Returns a DataAndFiles object. + Parses the incoming bytestream as a multipart encoded form, + and returns a DataAndFiles object. `.data` will be a `QueryDict` containing all the form parameters. `.files` will be a `QueryDict` containing all the form files. @@ -147,6 +140,9 @@ class XMLParser(BaseParser): media_type = 'application/xml' def parse(self, stream, media_type=None, parser_context=None): + """ + Parses the incoming bytestream as XML and returns the resulting data. + """ assert etree, 'XMLParser requires defusedxml to be installed' parser_context = parser_context or {} @@ -216,7 +212,8 @@ class FileUploadParser(BaseParser): def parse(self, stream, media_type=None, parser_context=None): """ - Returns a DataAndFiles object. + Treats the incoming bytestream as a raw file upload and returns + a `DateAndFiles` object. `.data` will be None (we expect request body to be a file content). `.files` will be a `QueryDict` containing one 'file' element. diff --git a/awx/lib/site-packages/rest_framework/permissions.py b/awx/lib/site-packages/rest_framework/permissions.py index 45fcfd6658..1036663e05 100644 --- a/awx/lib/site-packages/rest_framework/permissions.py +++ b/awx/lib/site-packages/rest_framework/permissions.py @@ -128,7 +128,7 @@ class DjangoModelPermissions(BasePermission): # Workaround to ensure DjangoModelPermissions are not applied # to the root view when using DefaultRouter. - if model_cls is None and getattr(view, '_ignore_model_permissions'): + if model_cls is None and getattr(view, '_ignore_model_permissions', False): return True assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' diff --git a/awx/lib/site-packages/rest_framework/relations.py b/awx/lib/site-packages/rest_framework/relations.py index e3675b5124..edaf76d6ee 100644 --- a/awx/lib/site-packages/rest_framework/relations.py +++ b/awx/lib/site-packages/rest_framework/relations.py @@ -12,7 +12,7 @@ from django.db.models.fields import BLANK_CHOICE_DASH from django.forms import widgets from django.forms.models import ModelChoiceIterator from django.utils.translation import ugettext_lazy as _ -from rest_framework.fields import Field, WritableField, get_component +from rest_framework.fields import Field, WritableField, get_component, is_simple_callable from rest_framework.reverse import reverse from rest_framework.compat import urlparse from rest_framework.compat import smart_text @@ -144,7 +144,12 @@ class RelatedField(WritableField): return None if self.many: - return [self.to_native(item) for item in value.all()] + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] + else: + # Also support non-queryset iterables. + # This allows us to also support plain lists of related items. + return [self.to_native(item) for item in value] return self.to_native(value) def field_from_native(self, data, files, field_name, into): @@ -242,7 +247,12 @@ class PrimaryKeyRelatedField(RelatedField): queryset = get_component(queryset, component) # Forward relationship - return [self.to_native(item.pk) for item in queryset.all()] + if is_simple_callable(getattr(queryset, 'all', None)): + return [self.to_native(item.pk) for item in queryset.all()] + else: + # Also support non-queryset iterables. + # This allows us to also support plain lists of related items. + return [self.to_native(item.pk) for item in queryset] # To-one relationship try: diff --git a/awx/lib/site-packages/rest_framework/renderers.py b/awx/lib/site-packages/rest_framework/renderers.py index b2fe43eac2..3a03ca3324 100644 --- a/awx/lib/site-packages/rest_framework/renderers.py +++ b/awx/lib/site-packages/rest_framework/renderers.py @@ -11,14 +11,15 @@ from __future__ import unicode_literals import copy import json from django import forms +from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template +from django.test.client import encode_multipart from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO from rest_framework.compat import six from rest_framework.compat import smart_text from rest_framework.compat import yaml -from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders @@ -270,7 +271,7 @@ class TemplateHTMLRenderer(BaseRenderer): return [self.template_name] elif hasattr(view, 'get_template_names'): return view.get_template_names() - raise ConfigurationError('Returned a template response with no template_name') + raise ImproperlyConfigured('Returned a template response with no template_name') def get_exception_template(self, response): template_names = [name % {'status_code': response.status_code} @@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer): response.status_code = status.HTTP_200_OK return ret + + +class MultiPartRenderer(BaseRenderer): + media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' + format = 'multipart' + charset = 'utf-8' + BOUNDARY = 'BoUnDaRyStRiNg' + + def render(self, data, accepted_media_type=None, renderer_context=None): + return encode_multipart(self.BOUNDARY, data) diff --git a/awx/lib/site-packages/rest_framework/request.py b/awx/lib/site-packages/rest_framework/request.py index 0d88ebc7e4..919716f49a 100644 --- a/awx/lib/site-packages/rest_framework/request.py +++ b/awx/lib/site-packages/rest_framework/request.py @@ -64,6 +64,20 @@ def clone_request(request, method): return ret +class ForcedAuthentication(object): + """ + This authentication class is used if the test client or request factory + forcibly authenticated the request. + """ + + def __init__(self, force_user, force_token): + self.force_user = force_user + self.force_token = force_token + + def authenticate(self, request): + return (self.force_user, self.force_token) + + class Request(object): """ Wrapper allowing to enhance a standard `HttpRequest` instance. @@ -98,6 +112,12 @@ class Request(object): self.parser_context['request'] = self self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET + force_user = getattr(request, '_force_auth_user', None) + force_token = getattr(request, '_force_auth_token', None) + if (force_user is not None or force_token is not None): + forced_auth = ForcedAuthentication(force_user, force_token) + self.authenticators = (forced_auth,) + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/awx/lib/site-packages/rest_framework/response.py b/awx/lib/site-packages/rest_framework/response.py index 3ee52ae01f..5877c8a3e9 100644 --- a/awx/lib/site-packages/rest_framework/response.py +++ b/awx/lib/site-packages/rest_framework/response.py @@ -12,7 +12,7 @@ from rest_framework.compat import six class Response(SimpleTemplateResponse): """ - An HttpResponse that allows it's data to be rendered into + An HttpResponse that allows its data to be rendered into arbitrary media types. """ diff --git a/awx/lib/site-packages/rest_framework/routers.py b/awx/lib/site-packages/rest_framework/routers.py index 6c5fd00483..930011d39e 100644 --- a/awx/lib/site-packages/rest_framework/routers.py +++ b/awx/lib/site-packages/rest_framework/routers.py @@ -15,10 +15,11 @@ For example, you might have a `urls.py` that looks something like this: """ from __future__ import unicode_literals +import itertools from collections import namedtuple +from django.core.exceptions import ImproperlyConfigured from rest_framework import views from rest_framework.compat import patterns, url -from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.urlpatterns import format_suffix_patterns @@ -39,6 +40,13 @@ def replace_methodname(format_string, methodname): return ret +def flatten(list_of_lists): + """ + Takes an iterable of iterables, returns a single iterable containing all items + """ + return itertools.chain(*list_of_lists) + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -72,7 +80,7 @@ class SimpleRouter(BaseRouter): routes = [ # List route. Route( - url=r'^{prefix}/$', + url=r'^{prefix}{trailing_slash}$', mapping={ 'get': 'list', 'post': 'create' @@ -82,7 +90,7 @@ class SimpleRouter(BaseRouter): ), # Detail route. Route( - url=r'^{prefix}/{lookup}/$', + url=r'^{prefix}/{lookup}{trailing_slash}$', mapping={ 'get': 'retrieve', 'put': 'update', @@ -95,7 +103,7 @@ class SimpleRouter(BaseRouter): # Dynamically generated routes. # Generated using @action or @link decorators on methods of the viewset. Route( - url=r'^{prefix}/{lookup}/{methodname}/$', + url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$', mapping={ '{httpmethod}': '{methodname}', }, @@ -104,6 +112,10 @@ class SimpleRouter(BaseRouter): ), ] + def __init__(self, trailing_slash=True): + self.trailing_slash = trailing_slash and '/' or '' + super(SimpleRouter, self).__init__() + def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine @@ -114,7 +126,7 @@ class SimpleRouter(BaseRouter): if model_cls is None and queryset is not None: model_cls = queryset.model - assert model_cls, '`name` not argument not specified, and could ' \ + assert model_cls, '`base_name` argument not specified, and could ' \ 'not automatically determine the name from the viewset, as ' \ 'it does not have a `.model` or `.queryset` attribute.' @@ -127,12 +139,18 @@ class SimpleRouter(BaseRouter): Returns a list of the Route namedtuple. """ + known_actions = flatten([route.mapping.values() for route in self.routes]) + # Determine any `@action` or `@link` decorated methods on the viewset dynamic_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) httpmethods = getattr(attr, 'bind_to_methods', None) if httpmethods: + if methodname in known_actions: + raise ImproperlyConfigured('Cannot use @action or @link decorator on ' + 'method "%s" as it is an existing route' % methodname) + httpmethods = [method.lower() for method in httpmethods] dynamic_routes.append((httpmethods, methodname)) ret = [] @@ -193,7 +211,11 @@ class SimpleRouter(BaseRouter): continue # Build the url pattern - regex = route.url.format(prefix=prefix, lookup=lookup) + regex = route.url.format( + prefix=prefix, + lookup=lookup, + trailing_slash=self.trailing_slash + ) view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) @@ -208,6 +230,7 @@ class DefaultRouter(SimpleRouter): """ include_root_view = True include_format_suffixes = True + root_view_name = 'api-root' def get_api_root_view(self): """ @@ -237,7 +260,7 @@ class DefaultRouter(SimpleRouter): urls = [] if self.include_root_view: - root_url = url(r'^$', self.get_api_root_view(), name='api-root') + root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name) urls.append(root_url) default_urls = super(DefaultRouter, self).get_urls() diff --git a/awx/lib/site-packages/rest_framework/runtests/settings.py b/awx/lib/site-packages/rest_framework/runtests/settings.py index 9dd7b545e6..b3702d0bfa 100644 --- a/awx/lib/site-packages/rest_framework/runtests/settings.py +++ b/awx/lib/site-packages/rest_framework/runtests/settings.py @@ -134,6 +134,8 @@ PASSWORD_HASHERS = ( 'django.contrib.auth.hashers.CryptPasswordHasher', ) +AUTH_USER_MODEL = 'auth.User' + import django if django.VERSION < (1, 3): diff --git a/awx/lib/site-packages/rest_framework/serializers.py b/awx/lib/site-packages/rest_framework/serializers.py index 11ead02e4f..31cfa34474 100644 --- a/awx/lib/site-packages/rest_framework/serializers.py +++ b/awx/lib/site-packages/rest_framework/serializers.py @@ -683,14 +683,14 @@ class ModelSerializer(Serializer): # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name not in self.base_fields.keys(), \ - "field '%s' on serializer '%s' specfied in " \ + "field '%s' on serializer '%s' specified in " \ "`read_only_fields`, but also added " \ - "as an explict field. Remove it from `read_only_fields`." % \ + "as an explicit field. Remove it from `read_only_fields`." % \ (field_name, self.__class__.__name__) assert field_name in ret, \ - "Noexistant field '%s' specified in `read_only_fields` " \ + "Non-existant field '%s' specified in `read_only_fields` " \ "on serializer '%s'." % \ - (self.__class__.__name__, field_name) + (field_name, self.__class__.__name__) ret[field_name].read_only = True return ret @@ -904,34 +904,23 @@ class HyperlinkedModelSerializer(ModelSerializer): _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField - # Just a placeholder to ensure 'url' is the first field - # The field itself is actually created on initialization, - # when the view_name and lookup_field arguments are available. - url = Field() - - def __init__(self, *args, **kwargs): - super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) + def get_default_fields(self): + fields = super(HyperlinkedModelSerializer, self).get_default_fields() if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - url_field = HyperlinkedIdentityField( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) - url_field.initialize(self, 'url') - self.fields['url'] = url_field + if 'url' not in fields: + url_field = HyperlinkedIdentityField( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + ret = self._dict_class() + ret['url'] = url_field + ret.update(fields) + fields = ret - def _get_default_view_name(self, model): - """ - Return the view name to use if 'view_name' is not specified in 'Meta' - """ - model_meta = model._meta - format_kwargs = { - 'app_label': model_meta.app_label, - 'model_name': model_meta.object_name.lower() - } - return self._default_view_name % format_kwargs + return fields def get_pk_field(self, model_field): if self.opts.fields and model_field.name in self.opts.fields: @@ -966,3 +955,14 @@ class HyperlinkedModelSerializer(ModelSerializer): return data.get('url', None) except AttributeError: return None + + def _get_default_view_name(self, model): + """ + Return the view name to use if 'view_name' is not specified in 'Meta' + """ + model_meta = model._meta + format_kwargs = { + 'app_label': model_meta.app_label, + 'model_name': model_meta.object_name.lower() + } + return self._default_view_name % format_kwargs diff --git a/awx/lib/site-packages/rest_framework/settings.py b/awx/lib/site-packages/rest_framework/settings.py index beb511aca2..8fd177d586 100644 --- a/awx/lib/site-packages/rest_framework/settings.py +++ b/awx/lib/site-packages/rest_framework/settings.py @@ -73,6 +73,13 @@ DEFAULTS = { 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # Testing + 'TEST_REQUEST_RENDERER_CLASSES': ( + 'rest_framework.renderers.MultiPartRenderer', + 'rest_framework.renderers.JSONRenderer' + ), + 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', + # Browser enhancements 'FORM_METHOD_OVERRIDE': '_method', 'FORM_CONTENT_OVERRIDE': '_content', @@ -115,6 +122,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', + 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html b/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html index 9d939e738b..51f9c2916b 100644 --- a/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html +++ b/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html @@ -196,7 +196,7 @@ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> {% endif %} {% if raw_data_patch_form %} - <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PUT request on the {{ name }} resource">PATCH</button> + <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PATCH request on the {{ name }} resource">PATCH</button> {% endif %} </div> </fieldset> diff --git a/awx/lib/site-packages/rest_framework/test.py b/awx/lib/site-packages/rest_framework/test.py new file mode 100644 index 0000000000..a18f5a2938 --- /dev/null +++ b/awx/lib/site-packages/rest_framework/test.py @@ -0,0 +1,157 @@ +# -- coding: utf-8 -- + +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order +# to make it harder for the user to import the wrong thing without realizing. +from __future__ import unicode_literals +import django +from django.conf import settings +from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler +from django.test import testcases +from rest_framework.settings import api_settings +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six + + +def force_authenticate(request, user=None, token=None): + request._force_auth_user = user + request._force_auth_token = token + + +class APIRequestFactory(DjangoRequestFactory): + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super(APIRequestFactory, self).__init__(**defaults) + + def _encode_data(self, data, format=None, content_type=None): + """ + Encode the data returning a two tuple of (bytes, content_type) + """ + + if not data: + return ('', None) + + assert format is None or content_type is None, ( + 'You may not set both `format` and `content_type`.' + ) + + if content_type: + # Content type specified explicitly, treat data as a raw bytestring + ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + + else: + format = format or self.default_format + + assert format in self.renderer_classes, ("Invalid format '{0}'. " + "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " + "to enable extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) + ) + ) + + # Use format and render the data into a bytestring + renderer = self.renderer_classes[format]() + ret = renderer.render(data) + + # Determine the content-type header from the renderer + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) + + # Coerce text to bytes if required. + if isinstance(ret, six.text_type): + ret = bytes(ret.encode(renderer.charset)) + + return ret, content_type + + def post(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('POST', path, data, content_type, **extra) + + def put(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PUT', path, data, content_type, **extra) + + def patch(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PATCH', path, data, content_type, **extra) + + def delete(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('DELETE', path, data, content_type, **extra) + + def options(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('OPTIONS', path, data, content_type, **extra) + + def request(self, **kwargs): + request = super(APIRequestFactory, self).request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + + +class ForceAuthClientHandler(ClientHandler): + """ + A patched version of ClientHandler that can enforce authentication + on the outgoing requests. + """ + + def __init__(self, *args, **kwargs): + self._force_user = None + self._force_token = None + super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + + def get_response(self, request): + # This is the simplest place we can hook into to patch the + # request object. + force_authenticate(request, self._force_user, self._force_token) + return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, enforce_csrf_checks=False, **defaults): + super(APIClient, self).__init__(**defaults) + self.handler = ForceAuthClientHandler(enforce_csrf_checks) + self._credentials = {} + + def credentials(self, **kwargs): + """ + Sets headers that will be used on every outgoing request. + """ + self._credentials = kwargs + + def force_authenticate(self, user=None, token=None): + """ + Forcibly authenticates outgoing requests with the given + user and/or token. + """ + self.handler._force_user = user + self.handler._force_token = token + + def request(self, **kwargs): + # Ensure that any credentials set get added to every request. + kwargs.update(self._credentials) + return super(APIClient, self).request(**kwargs) + + +class APITransactionTestCase(testcases.TransactionTestCase): + client_class = APIClient + + +class APITestCase(testcases.TestCase): + client_class = APIClient + + +if django.VERSION >= (1, 4): + class APISimpleTestCase(testcases.SimpleTestCase): + client_class = APIClient + + class APILiveServerTestCase(testcases.LiveServerTestCase): + client_class = APIClient diff --git a/awx/lib/site-packages/rest_framework/tests/description.py b/awx/lib/site-packages/rest_framework/tests/description.py new file mode 100644 index 0000000000..b46d7f54d8 --- /dev/null +++ b/awx/lib/site-packages/rest_framework/tests/description.py @@ -0,0 +1,26 @@ +# -- coding: utf-8 -- + +# Apparently there is a python 2.6 issue where docstrings of imported view classes +# do not retain their encoding information even if a module has a proper +# encoding declaration at the top of its source file. Therefore for tests +# to catch unicode related errors, a mock view has to be declared in a separate +# module. + +from rest_framework.views import APIView + + +# test strings snatched from http://www.columbia.edu/~fdc/utf8/, +# http://winrus.com/utf8-jap.htm and memory +UTF8_TEST_DOCSTRING = ( + 'zażółć gęślą jaźń' + 'Sîne klâwen durh die wolken sint geslagen' + 'Τη γλώσσα μου έδωσαν ελληνική' + 'யாமறிந்த மொழிகளிலே தமிழ்மொழி' + 'На берегу пустынных волн' + 'てすと' + 'アイウエオカキクケコサシスセソタチツテ' +) + + +class ViewWithNonASCIICharactersInDocstring(APIView): + __doc__ = UTF8_TEST_DOCSTRING diff --git a/awx/lib/site-packages/rest_framework/tests/models.py b/awx/lib/site-packages/rest_framework/tests/models.py index e2d4eacdc9..1598ecd94a 100644 --- a/awx/lib/site-packages/rest_framework/tests/models.py +++ b/awx/lib/site-packages/rest_framework/tests/models.py @@ -52,7 +52,7 @@ class CallableDefaultValueModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel): - rel = models.ManyToManyField(Anchor) + rel = models.ManyToManyField(Anchor, help_text='Some help text.') class ReadOnlyManyToManyModel(RESTFrameworkModel): diff --git a/awx/lib/site-packages/rest_framework/tests/test_authentication.py b/awx/lib/site-packages/rest_framework/tests/test_authentication.py index d46ac07985..a44813b691 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_authentication.py +++ b/awx/lib/site-packages/rest_framework/tests/test_authentication.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse -from django.test import Client, TestCase +from django.test import TestCase from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions @@ -21,14 +21,13 @@ from rest_framework.authtoken.models import Token from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView -import json import base64 import time import datetime -factory = RequestFactory() +factory = APIRequestFactory() class MockView(APIView): @@ -68,7 +67,7 @@ class BasicAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -87,7 +86,7 @@ class BasicAuthTests(TestCase): credentials = ('%s:%s' % (self.username, self.password)) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) auth = 'Basic %s' % base64_credentials - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): @@ -97,7 +96,7 @@ class BasicAuthTests(TestCase): def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') @@ -107,8 +106,8 @@ class SessionAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) - self.non_csrf_client = Client(enforce_csrf_checks=False) + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.non_csrf_client = APIClient(enforce_csrf_checks=False) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -154,7 +153,7 @@ class TokenAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -172,7 +171,7 @@ class TokenAuthTests(TestCase): def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): @@ -182,7 +181,7 @@ class TokenAuthTests(TestCase): def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): @@ -193,33 +192,33 @@ class TokenAuthTests(TestCase): def test_token_login_json(self): """Ensure token login view using JSON POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': self.password}), 'application/json') + {'username': self.username, 'password': self.password}, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) def test_token_login_json_bad_creds(self): """Ensure token login view using JSON POST fails if bad credentials are used.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') + {'username': self.username, 'password': "badpass"}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_json_missing_fields(self): """Ensure token login view using JSON POST fails if missing fields.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username}), 'application/json') + {'username': self.username}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_form(self): """Ensure token login view using form POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) class IncorrectCredentialsTests(TestCase): @@ -256,7 +255,7 @@ class OAuthTests(TestCase): self.consts = consts - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -428,13 +427,55 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth-with-scope/', params) self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_bad_consumer_key(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': 'badconsumerkey' + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_bad_token_key(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': 'badtokenkey', + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' diff --git a/awx/lib/site-packages/rest_framework/tests/test_decorators.py b/awx/lib/site-packages/rest_framework/tests/test_decorators.py index 1016fed3ff..195f0ba3e4 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_decorators.py +++ b/awx/lib/site-packages/rest_framework/tests/test_decorators.py @@ -1,12 +1,13 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.renderers import JSONRenderer -from rest_framework.parsers import JSONParser -from rest_framework.authentication import BasicAuthentication +from rest_framework.test import APIRequestFactory from rest_framework.throttling import UserRateThrottle -from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView from rest_framework.decorators import ( api_view, @@ -17,13 +18,11 @@ from rest_framework.decorators import ( permission_classes, ) -from rest_framework.tests.utils import RequestFactory - class DecoratorTestCase(TestCase): def setUp(self): - self.factory = RequestFactory() + self.factory = APIRequestFactory() def _finalize_response(self, request, response, *args, **kwargs): response.request = request diff --git a/awx/lib/site-packages/rest_framework/tests/test_description.py b/awx/lib/site-packages/rest_framework/tests/test_description.py index 52c1a34c10..8019f5ecaf 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_description.py +++ b/awx/lib/site-packages/rest_framework/tests/test_description.py @@ -2,8 +2,10 @@ from __future__ import unicode_literals from django.test import TestCase +from rest_framework.compat import apply_markdown, smart_text from rest_framework.views import APIView -from rest_framework.compat import apply_markdown +from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring +from rest_framework.tests.description import UTF8_TEST_DOCSTRING from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. @@ -83,11 +85,10 @@ class TestViewNamesAndDescriptions(TestCase): Unicode in docstrings should be respected. """ - class MockView(APIView): - """Проверка""" - pass - - self.assertEqual(get_view_description(MockView), "Проверка") + self.assertEqual( + get_view_description(ViewWithNonASCIICharactersInDocstring), + smart_text(UTF8_TEST_DOCSTRING) + ) def test_view_description_can_be_empty(self): """ diff --git a/awx/lib/site-packages/rest_framework/tests/test_fields.py b/awx/lib/site-packages/rest_framework/tests/test_fields.py index de3710011c..6836ec86f6 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_fields.py +++ b/awx/lib/site-packages/rest_framework/tests/test_fields.py @@ -865,4 +865,34 @@ class FieldCallableDefault(TestCase): field = serializers.WritableField(default=self.simple_callable) into = {} field.field_from_native({}, {}, 'field', into) - self.assertEquals(into, {'field': 'foo bar'}) + self.assertEqual(into, {'field': 'foo bar'}) + + +class CustomIntegerField(TestCase): + """ + Test that custom fields apply min_value and max_value constraints + """ + def test_custom_fields_can_be_validated_for_value(self): + + class MoneyField(models.PositiveIntegerField): + pass + + class EntryModel(models.Model): + bank = MoneyField(validators=[validators.MaxValueValidator(100)]) + + class EntrySerializer(serializers.ModelSerializer): + class Meta: + model = EntryModel + + entry = EntryModel(bank=1) + + serializer = EntrySerializer(entry, data={"bank": 11}) + self.assertTrue(serializer.is_valid()) + + serializer = EntrySerializer(entry, data={"bank": -1}) + self.assertFalse(serializer.is_valid()) + + serializer = EntrySerializer(entry, data={"bank": 101}) + self.assertFalse(serializer.is_valid()) + + diff --git a/awx/lib/site-packages/rest_framework/tests/test_filters.py b/awx/lib/site-packages/rest_framework/tests/test_filters.py index aaed624782..c9d9e7ffaa 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_filters.py +++ b/awx/lib/site-packages/rest_framework/tests/test_filters.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): diff --git a/awx/lib/site-packages/rest_framework/tests/test_generics.py b/awx/lib/site-packages/rest_framework/tests/test_generics.py index 37734195aa..1550880b56 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_generics.py +++ b/awx/lib/site-packages/rest_framework/tests/test_generics.py @@ -3,12 +3,11 @@ from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, renderers, serializers, status -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six -import json -factory = RequestFactory() +factory = APIRequestFactory() class RootView(generics.ListCreateAPIView): @@ -71,9 +70,8 @@ class TestRootView(TestCase): """ POST requests to ListCreateAPIView should create a new object. """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -85,9 +83,8 @@ class TestRootView(TestCase): """ PUT requests to ListCreateAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.put('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -148,9 +145,8 @@ class TestRootView(TestCase): """ POST requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -189,9 +185,8 @@ class TestInstanceView(TestCase): """ POST requests to RetrieveUpdateDestroyAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -201,9 +196,8 @@ class TestInstanceView(TestCase): """ PUT requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -215,9 +209,8 @@ class TestInstanceView(TestCase): """ PATCH requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.patch('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.patch('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() @@ -293,9 +286,8 @@ class TestInstanceView(TestCase): """ PUT requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -309,9 +301,8 @@ class TestInstanceView(TestCase): if it does not currently exist. """ self.objects.get(id=1).delete() - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -324,10 +315,9 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if it doesn't exist. """ - content = {'text': 'foobar'} + data = {'text': 'foobar'} # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', json.dumps(content), - content_type='application/json') + request = factory.put('/5', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=5).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -339,9 +329,8 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. """ - content = {'text': 'foobar'} - request = factory.put('/test_slug', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/test_slug', data, format='json') with self.assertNumQueries(2): response = self.slug_based_view(request, slug='test_slug').render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase): https://github.com/tomchristie/django-rest-framework/issues/285 """ - content = {'email': 'foobar@example.com', 'content': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'email': 'foobar@example.com', 'content': 'foobar'} + request = factory.post('/', data, format='json') response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) diff --git a/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py b/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py index 1894ddb27e..61e613d75e 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py +++ b/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals import json from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import generics, status, serializers from rest_framework.compat import patterns, url -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel +from rest_framework.test import APIRequestFactory +from rest_framework.tests.models import ( + Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, + Album, Photo, OptionalRelationModel +) -factory = RequestFactory() +factory = APIRequestFactory() class BlogPostCommentSerializer(serializers.ModelSerializer): @@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer): class PhotoSerializer(serializers.Serializer): description = serializers.CharField() - album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') + album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') def restore_object(self, attrs, instance=None): return Photo(**attrs) @@ -301,3 +304,30 @@ class TestOptionalRelationHyperlinkedView(TestCase): data=json.dumps(self.data), content_type='application/json') self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class TestOverriddenURLField(TestCase): + def setUp(self): + class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.SerializerMethodField('get_url') + + class Meta: + model = BlogPost + fields = ('title', 'url') + + def get_url(self, obj): + return 'foo bar' + + self.Serializer = OverriddenURLSerializer + self.obj = BlogPost.objects.create(title='New blog post') + + def test_overridden_url_field(self): + """ + The 'url' field should respect overriding. + Regression test for #936. + """ + serializer = self.Serializer(self.obj) + self.assertEqual( + serializer.data, + {'title': 'New blog post', 'url': 'foo bar'} + ) diff --git a/awx/lib/site-packages/rest_framework/tests/test_negotiation.py b/awx/lib/site-packages/rest_framework/tests/test_negotiation.py index 7f84827f0e..04b89eb600 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_negotiation.py +++ b/awx/lib/site-packages/rest_framework/tests/test_negotiation.py @@ -1,12 +1,12 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.request import Request from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() class MockJSONRenderer(BaseRenderer): diff --git a/awx/lib/site-packages/rest_framework/tests/test_pagination.py b/awx/lib/site-packages/rest_framework/tests/test_pagination.py index e538a78e5b..85d4640ead 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_pagination.py +++ b/awx/lib/site-packages/rest_framework/tests/test_pagination.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.paginator import Paginator from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): @@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase): self.page = paginator.page(1) def test_custom_pagination_serializer(self): - request = RequestFactory().get('/foobar') + request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=self.page, context={'request': request} diff --git a/awx/lib/site-packages/rest_framework/tests/test_permissions.py b/awx/lib/site-packages/rest_framework/tests/test_permissions.py index 6caaf65b02..e2cca3808c 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_permissions.py +++ b/awx/lib/site-packages/rest_framework/tests/test_permissions.py @@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission from django.db import models from django.test import TestCase from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory import base64 -import json -factory = RequestFactory() +factory = APIRequestFactory() class BasicModel(models.Model): @@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase): BasicModel(text='foo').save() def test_has_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_has_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_does_not_have_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_does_not_have_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase): def test_has_put_as_create_permissions(self): # User only has update permissions - should be able to update an entity. - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) # But if PUTing to a new entity, permission should be denied. - request = factory.put('/2', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/2', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_options_permitted(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn('actions', response.data) self.assertEqual(list(response.data['actions'].keys()), ['POST']) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(list(response.data['actions'].keys()), ['PUT']) def test_options_disallowed(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) def test_options_updateonly(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.updateonly_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py b/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py index 2ca7f4f2b3..3c4d39af63 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py +++ b/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import ( BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) -factory = RequestFactory() +factory = APIRequestFactory() request = factory.get('/') # Just to ensure we have a request in the serializer context diff --git a/awx/lib/site-packages/rest_framework/tests/test_renderers.py b/awx/lib/site-packages/rest_framework/tests/test_renderers.py index 95b597411f..df6f4aa63f 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_renderers.py +++ b/awx/lib/site-packages/rest_framework/tests/test_renderers.py @@ -4,19 +4,17 @@ from __future__ import unicode_literals from decimal import Decimal from django.core.cache import cache from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings -from rest_framework.compat import StringIO -from rest_framework.compat import six +from rest_framework.test import APIRequestFactory import datetime import pickle import re @@ -121,7 +119,7 @@ class POSTDeniedView(APIView): class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): view = POSTDeniedView.as_view() - request = RequestFactory().get('/') + request = APIRequestFactory().get('/') response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') diff --git a/awx/lib/site-packages/rest_framework/tests/test_request.py b/awx/lib/site-packages/rest_framework/tests/test_request.py index a5c5e84ce7..969d8024a9 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_request.py +++ b/awx/lib/site-packages/rest_framework/tests/test_request.py @@ -5,8 +5,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware -from django.test import TestCase, Client -from django.test.client import RequestFactory +from django.test import TestCase from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -19,12 +18,13 @@ from rest_framework.parsers import ( from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView from rest_framework.compat import six import json -factory = RequestFactory() +factory = APIRequestFactory() class PlainTextParser(BaseParser): @@ -116,16 +116,7 @@ class TestContentParsing(TestCase): Ensure request.DATA returns content for PUT request with form content. """ data = {'qwerty': 'uiop'} - - from django import VERSION - - if VERSION >= (1, 5): - from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart - request = Request(factory.put('/', encode_multipart(BOUNDARY, data), - content_type=MULTIPART_CONTENT)) - else: - request = Request(factory.put('/', data)) - + request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.DATA.items()), list(data.items())) @@ -257,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase): urls = 'rest_framework.tests.test_request' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' diff --git a/awx/lib/site-packages/rest_framework/tests/test_reverse.py b/awx/lib/site-packages/rest_framework/tests/test_reverse.py index 93ef563776..690a30b119 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_reverse.py +++ b/awx/lib/site-packages/rest_framework/tests/test_reverse.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.compat import patterns, url from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() def null_view(request): diff --git a/awx/lib/site-packages/rest_framework/tests/test_routers.py b/awx/lib/site-packages/rest_framework/tests/test_routers.py index 10d3cc25a0..5fcccb7414 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_routers.py +++ b/awx/lib/site-packages/rest_framework/tests/test_routers.py @@ -1,14 +1,15 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase -from django.test.client import RequestFactory -from rest_framework import serializers, viewsets +from django.core.exceptions import ImproperlyConfigured +from rest_framework import serializers, viewsets, permissions from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action from rest_framework.response import Response -from rest_framework.routers import SimpleRouter +from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() urlpatterns = patterns('',) @@ -50,7 +51,7 @@ class TestSimpleRouter(TestCase): route = decorator_routes[i] # check url listing self.assertEqual(route.url, - '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) + '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) # check method to function mapping if endpoint == 'action3': methods_map = ['post', 'delete'] @@ -103,7 +104,7 @@ class TestCustomLookupFields(TestCase): def test_retrieve_lookup_field_list_view(self): response = self.client.get('/notes/') - self.assertEquals(response.data, + self.assertEqual(response.data, [{ "url": "http://testserver/notes/123/", "uuid": "123", "text": "foo bar" @@ -112,10 +113,104 @@ class TestCustomLookupFields(TestCase): def test_retrieve_lookup_field_detail_view(self): response = self.client.get('/notes/123/') - self.assertEquals(response.data, + self.assertEqual(response.data, { "url": "http://testserver/notes/123/", "uuid": "123", "text": "foo bar" } ) + +class TestTrailingSlashIncluded(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_urls_have_trailing_slash_by_default(self): + expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] + for idx in range(len(expected)): + self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlashRemoved(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + + self.router = SimpleRouter(trailing_slash=False) + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_urls_can_have_trailing_slash_removed(self): + expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] + for idx in range(len(expected)): + self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestNameableRoot(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + self.router = DefaultRouter() + self.router.root_view_name = 'nameable-root' + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_router_has_custom_name(self): + expected = 'nameable-root' + self.assertEqual(expected, self.urls[0].name) + + +class TestActionKeywordArgs(TestCase): + """ + Ensure keyword arguments passed in the `@action` decorator + are properly handled. Refs #940. + """ + + def setUp(self): + class TestViewSet(viewsets.ModelViewSet): + permission_classes = [] + + @action(permission_classes=[permissions.AllowAny]) + def custom(self, request, *args, **kwargs): + return Response({ + 'permission_classes': self.permission_classes + }) + + self.router = SimpleRouter() + self.router.register(r'test', TestViewSet, base_name='test') + self.view = self.router.urls[-1].callback + + def test_action_kwargs(self): + request = factory.post('/test/0/custom/') + response = self.view(request) + self.assertEqual( + response.data, + {'permission_classes': [permissions.AllowAny]} + ) + + +class TestActionAppliedToExistingRoute(TestCase): + """ + Ensure `@action` decorator raises an except when applied + to an existing route + """ + + def test_exception_raised_when_action_applied_to_existing_route(self): + class TestViewSet(viewsets.ModelViewSet): + + @action() + def retrieve(self, request, *args, **kwargs): + return Response({ + 'hello': 'world' + }) + + self.router = SimpleRouter() + self.router.register(r'test', TestViewSet, base_name='test') + + with self.assertRaises(ImproperlyConfigured): + self.router.urls diff --git a/awx/lib/site-packages/rest_framework/tests/test_serializer.py b/awx/lib/site-packages/rest_framework/tests/test_serializer.py index 6cc913c5cd..c24976603e 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_serializer.py +++ b/awx/lib/site-packages/rest_framework/tests/test_serializer.py @@ -494,7 +494,7 @@ class CustomValidationTests(TestCase): } serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']}) + self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']}) class PositiveIntegerAsChoiceTests(TestCase): @@ -1376,6 +1376,18 @@ class FieldLabelTest(TestCase): self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) +# Test for issue #961 + +class ManyFieldHelpTextTest(TestCase): + def test_help_text_no_hold_down_control_msg(self): + """ + Validate that help_text doesn't contain the 'Hold down "Control" ...' + message that Django appends to choice fields. + """ + rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text) + self.assertEqual('Some help text.', rel_field.help_text) + + class AttributeMappingOnAutogeneratedFieldsTests(TestCase): def setUp(self): @@ -1556,3 +1568,78 @@ class MetadataSerializerTestCase(TestCase): } } self.assertEqual(expected, metadata) + + +### Regression test for #840 + +class SimpleModel(models.Model): + text = models.CharField(max_length=100) + + +class SimpleModelSerializer(serializers.ModelSerializer): + text = serializers.CharField() + other = serializers.CharField() + + class Meta: + model = SimpleModel + + def validate_other(self, attrs, source): + del attrs['other'] + return attrs + + +class FieldValidationRemovingAttr(TestCase): + def test_removing_non_model_field_in_validation(self): + """ + Removing an attr during field valiation should ensure that it is not + passed through when restoring the object. + + This allows additional non-model fields to be supported. + + Regression test for #840. + """ + serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'}) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEqual(serializer.object.text, 'foo') + + +### Regression test for #878 + +class SimpleTargetModel(models.Model): + text = models.CharField(max_length=100) + + +class SimplePKSourceModelSerializer(serializers.Serializer): + targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True) + text = serializers.CharField() + + +class SimpleSlugSourceModelSerializer(serializers.Serializer): + targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk') + text = serializers.CharField() + + +class SerializerSupportsManyRelationships(TestCase): + def setUp(self): + SimpleTargetModel.objects.create(text='foo') + SimpleTargetModel.objects.create(text='bar') + + def test_serializer_supports_pk_many_relationships(self): + """ + Regression test for #878. + + Note that pk behavior has a different code path to usual cases, + for performance reasons. + """ + serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) + + def test_serializer_supports_slug_many_relationships(self): + """ + Regression test for #878. + """ + serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) diff --git a/awx/lib/site-packages/rest_framework/tests/test_testing.py b/awx/lib/site-packages/rest_framework/tests/test_testing.py new file mode 100644 index 0000000000..49d45fc292 --- /dev/null +++ b/awx/lib/site-packages/rest_framework/tests/test_testing.py @@ -0,0 +1,115 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.contrib.auth.models import User +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate + + +@api_view(['GET', 'POST']) +def view(request): + return Response({ + 'auth': request.META.get('HTTP_AUTHORIZATION', b''), + 'user': request.user.username + }) + + +urlpatterns = patterns('', + url(r'^view/$', view), +) + + +class TestAPITestClient(TestCase): + urls = 'rest_framework.tests.test_testing' + + def setUp(self): + self.client = APIClient() + + def test_credentials(self): + """ + Setting `.credentials()` adds the required headers to each request. + """ + self.client.credentials(HTTP_AUTHORIZATION='example') + for _ in range(0, 3): + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') + + def test_force_authenticate(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + response = self.client.get('/view/') + self.assertEqual(response.data['user'], 'example') + + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + User.objects.create_user('example', 'example@example.com', 'password') + self.client.login(username='example', password='password') + response = self.client.post('/view/') + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + client = APIClient(enforce_csrf_checks=True) + User.objects.create_user('example', 'example@example.com', 'password') + client.login(username='example', password='password') + response = client.post('/view/') + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory() + request = factory.post('/view/') + request.user = user + response = view(request) + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory(enforce_csrf_checks=True) + request = factory.post('/view/') + request.user = user + response = view(request) + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + def test_invalid_format(self): + """ + Attempting to use a format that is not configured will raise an + assertion error. + """ + factory = APIRequestFactory() + self.assertRaises(AssertionError, factory.post, + path='/view/', data={'example': 1}, format='xml' + ) + + def test_force_authenticate(self): + """ + Setting `force_authenticate()` forcibly authenticates the request. + """ + user = User.objects.create_user('example', 'example@example.com') + factory = APIRequestFactory() + request = factory.get('/view') + force_authenticate(request, user=user) + response = view(request) + self.assertEqual(response.data['user'], 'example') diff --git a/awx/lib/site-packages/rest_framework/tests/test_throttling.py b/awx/lib/site-packages/rest_framework/tests/test_throttling.py index da400b2fcd..41bff6926a 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_throttling.py +++ b/awx/lib/site-packages/rest_framework/tests/test_throttling.py @@ -5,9 +5,9 @@ from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -from rest_framework.throttling import UserRateThrottle +from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -21,6 +21,14 @@ class User3MinRateThrottle(UserRateThrottle): scope = 'minutes' +class NonTimeThrottle(BaseThrottle): + def allow_request(self, request, view): + if not hasattr(self.__class__, 'called'): + self.__class__.called = True + return True + return False + + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -35,15 +43,20 @@ class MockView_MinuteThrottling(APIView): return Response('foo') -class ThrottlingTests(TestCase): - urls = 'rest_framework.tests.test_throttling' +class MockView_NonTimeThrottling(APIView): + throttle_classes = (NonTimeThrottle,) + + def get(self, request): + return Response('foo') + +class ThrottlingTests(TestCase): def setUp(self): """ Reset the cache so that no throttles will be active """ cache.clear() - self.factory = RequestFactory() + self.factory = APIRequestFactory() def test_requests_are_throttled(self): """ @@ -141,3 +154,124 @@ class ThrottlingTests(TestCase): (60, None), (80, None) )) + + def test_non_time_throttle(self): + """ + Ensure for second based throttles. + """ + request = self.factory.get('/') + + self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) + + response = MockView_NonTimeThrottling.as_view()(request) + self.assertFalse('X-Throttle-Wait-Seconds' in response) + + self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) + + response = MockView_NonTimeThrottling.as_view()(request) + self.assertFalse('X-Throttle-Wait-Seconds' in response) + + +class ScopedRateThrottleTests(TestCase): + """ + Tests for ScopedRateThrottle. + """ + + def setUp(self): + class XYScopedRateThrottle(ScopedRateThrottle): + TIMER_SECONDS = 0 + THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} + timer = lambda self: self.TIMER_SECONDS + + class XView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'x' + + def get(self, request): + return Response('x') + + class YView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'y' + + def get(self, request): + return Response('y') + + class UnscopedView(APIView): + throttle_classes = (XYScopedRateThrottle,) + + def get(self, request): + return Response('y') + + self.throttle_class = XYScopedRateThrottle + self.factory = APIRequestFactory() + self.x_view = XView.as_view() + self.y_view = YView.as_view() + self.unscoped_view = UnscopedView.as_view() + + def increment_timer(self, seconds=1): + self.throttle_class.TIMER_SECONDS += seconds + + def test_scoped_rate_throttle(self): + request = self.factory.get('/') + + # Should be able to hit x view 3 times per minute. + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(429, response.status_code) + + # Should be able to hit y view 1 time per minute. + self.increment_timer() + response = self.y_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.y_view(request) + self.assertEqual(429, response.status_code) + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(55) + + # Should still be able to hit x view 3 times per minute. + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(429, response.status_code) + + # Should still be able to hit y view 1 time per minute. + self.increment_timer() + response = self.y_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.y_view(request) + self.assertEqual(429, response.status_code) + + def test_unscoped_view_not_throttled(self): + request = self.factory.get('/') + + for idx in range(10): + self.increment_timer() + response = self.unscoped_view(request) + self.assertEqual(200, response.status_code) diff --git a/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py b/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py index 29ed4a961c..8132ec4c8e 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py +++ b/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from collections import namedtuple from django.core import urlresolvers from django.test import TestCase -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.compat import patterns, url, include from rest_framework.urlpatterns import format_suffix_patterns @@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase): Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. """ def _resolve_urlpatterns(self, urlpatterns, test_paths): - factory = RequestFactory() + factory = APIRequestFactory() try: urlpatterns = format_suffix_patterns(urlpatterns) except Exception: diff --git a/awx/lib/site-packages/rest_framework/tests/test_validation.py b/awx/lib/site-packages/rest_framework/tests/test_validation.py index a6ec0e993d..ebfdff9cd1 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_validation.py +++ b/awx/lib/site-packages/rest_framework/tests/test_validation.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status -from rest_framework.tests.utils import RequestFactory -import json +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() # Regression for #666 @@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase): validation on read only fields. """ obj = ValidationModel.objects.create(blank_validated_field='') - request = factory.put('/', json.dumps({}), - content_type='application/json') + request = factory.put('/', {}, format='json') view = UpdateValidationModel().as_view() response = view(request, pk=obj.pk).render() self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/awx/lib/site-packages/rest_framework/tests/test_views.py b/awx/lib/site-packages/rest_framework/tests/test_views.py index 2767d24c80..c0bec5aed1 100644 --- a/awx/lib/site-packages/rest_framework/tests/test_views.py +++ b/awx/lib/site-packages/rest_framework/tests/test_views.py @@ -1,17 +1,15 @@ from __future__ import unicode_literals import copy - from django.test import TestCase -from django.test.client import RequestFactory - from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -factory = RequestFactory() +factory = APIRequestFactory() class BasicView(APIView): diff --git a/awx/lib/site-packages/rest_framework/tests/utils.py b/awx/lib/site-packages/rest_framework/tests/utils.py deleted file mode 100644 index 8c87917d92..0000000000 --- a/awx/lib/site-packages/rest_framework/tests/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import unicode_literals -from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory -from django.test.client import MULTIPART_CONTENT -from rest_framework.compat import urlparse - - -class RequestFactory(_RequestFactory): - - def __init__(self, **defaults): - super(RequestFactory, self).__init__(**defaults) - - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - **extra): - "Construct a PATCH request." - - patch_data = self._encode_data(data, content_type) - - parsed = urlparse.urlparse(path) - r = { - 'CONTENT_LENGTH': len(patch_data), - 'CONTENT_TYPE': content_type, - 'PATH_INFO': self._get_path(parsed), - 'QUERY_STRING': parsed[4], - 'REQUEST_METHOD': 'PATCH', - 'wsgi.input': FakePayload(patch_data), - } - r.update(extra) - return self.request(**r) - - -class Client(_Client, RequestFactory): - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - follow=False, **extra): - """ - Send a resource to the server using PATCH. - """ - response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response diff --git a/awx/lib/site-packages/rest_framework/throttling.py b/awx/lib/site-packages/rest_framework/throttling.py index 93ea9816cb..65b4559307 100644 --- a/awx/lib/site-packages/rest_framework/throttling.py +++ b/awx/lib/site-packages/rest_framework/throttling.py @@ -3,7 +3,7 @@ Provides various throttling policies. """ from __future__ import unicode_literals from django.core.cache import cache -from rest_framework import exceptions +from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings import time @@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle): """ timer = time.time - settings = api_settings cache_format = 'throtte_%(scope)s_%(ident)s' scope = None + THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): @@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle): if not getattr(self, 'scope', None): msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) - raise exceptions.ConfigurationError(msg) + raise ImproperlyConfigured(msg) try: - return self.settings.DEFAULT_THROTTLE_RATES[self.scope] + return self.THROTTLE_RATES[self.scope] except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope - raise exceptions.ConfigurationError(msg) + raise ImproperlyConfigured(msg) def parse_rate(self, rate): """ @@ -96,6 +96,9 @@ class SimpleRateThrottle(BaseThrottle): return True self.key = self.get_cache_key(request, view) + if self.key is None: + return True + self.history = cache.get(self.key, []) self.now = self.timer() @@ -187,6 +190,27 @@ class ScopedRateThrottle(SimpleRateThrottle): """ scope_attr = 'throttle_scope' + def __init__(self): + # Override the usual SimpleRateThrottle, because we can't determine + # the rate until called by the view. + pass + + def allow_request(self, request, view): + # We can only determine the scope once we're called by the view. + self.scope = getattr(view, self.scope_attr, None) + + # If a view does not have a `throttle_scope` always allow the request + if not self.scope: + return True + + # Determine the allowed request rate as we normally would during + # the `__init__` call. + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) + + # We can now proceed as normal. + return super(ScopedRateThrottle, self).allow_request(request, view) + def get_cache_key(self, request, view): """ If `view.throttle_scope` is not set, don't apply this throttle. @@ -194,18 +218,12 @@ class ScopedRateThrottle(SimpleRateThrottle): Otherwise generate the unique cache key by concatenating the user id with the '.throttle_scope` property of the view. """ - scope = getattr(view, self.scope_attr, None) - - if not scope: - # Only throttle views if `.throttle_scope` is set on the view. - return None - if request.user.is_authenticated(): ident = request.user.id else: ident = request.META.get('REMOTE_ADDR', None) return self.cache_format % { - 'scope': scope, + 'scope': self.scope, 'ident': ident } diff --git a/awx/lib/site-packages/rest_framework/utils/formatting.py b/awx/lib/site-packages/rest_framework/utils/formatting.py index ebadb3a670..4bec838776 100644 --- a/awx/lib/site-packages/rest_framework/utils/formatting.py +++ b/awx/lib/site-packages/rest_framework/utils/formatting.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.utils.html import escape from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown +from rest_framework.compat import apply_markdown, smart_text import re @@ -63,7 +63,7 @@ def get_view_description(cls, html=False): Return a description for an `APIView` class or `@api_view` function. """ description = cls.__doc__ or '' - description = _remove_leading_indent(description) + description = _remove_leading_indent(smart_text(description)) if html: return markup_description(description) return description diff --git a/awx/lib/site-packages/rest_framework/views.py b/awx/lib/site-packages/rest_framework/views.py index e1b6705b6d..d51233a932 100644 --- a/awx/lib/site-packages/rest_framework/views.py +++ b/awx/lib/site-packages/rest_framework/views.py @@ -4,11 +4,11 @@ Provides an APIView class that is the base of all views in REST framework. from __future__ import unicode_literals from django.core.exceptions import PermissionDenied -from django.http import Http404, HttpResponse +from django.http import Http404 from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View +from rest_framework.compat import View, HttpResponseBase from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -244,9 +244,10 @@ class APIView(View): Returns the final response object. """ # Make the error obvious if a proper response is not returned - assert isinstance(response, HttpResponse), ( - 'Expected a `Response` to be returned from the view, ' - 'but received a `%s`' % type(response) + assert isinstance(response, HttpResponseBase), ( + 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` ' + 'to be returned from the view, but received a `%s`' + % type(response) ) if isinstance(response, Response): @@ -268,7 +269,7 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, exceptions.Throttled): + if isinstance(exc, exceptions.Throttled) and exc.wait is not None: # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait @@ -304,10 +305,10 @@ class APIView(View): `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ - request = self.initialize_request(request, *args, **kwargs) - self.request = request self.args = args self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request self.headers = self.default_response_headers # deprecate? try: @@ -341,8 +342,15 @@ class APIView(View): Return a dictionary of metadata about the view. Used to return responses for OPTIONS requests. """ + + # This is used by ViewSets to disambiguate instance vs list views + view_name_suffix = getattr(self, 'suffix', None) + + # By default we can't provide any form-like information, however the + # generic views override this implementation and add additional + # information for POST and PUT methods, based on the serializer. ret = SortedDict() - ret['name'] = get_view_name(self.__class__) + ret['name'] = get_view_name(self.__class__, view_name_suffix) ret['description'] = get_view_description(self.__class__) ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes] diff --git a/awx/lib/site-packages/south/__init__.py b/awx/lib/site-packages/south/__init__.py index 20c39178e6..2921dad23d 100644 --- a/awx/lib/site-packages/south/__init__.py +++ b/awx/lib/site-packages/south/__init__.py @@ -2,7 +2,7 @@ South - Useable migrations for Django apps """ -__version__ = "0.8.1" +__version__ = "0.8.2" __authors__ = [ "Andrew Godwin <andrew@aeracode.org>", "Andy McCurdy <andy@andymccurdy.com>" diff --git a/awx/lib/site-packages/south/creator/actions.py b/awx/lib/site-packages/south/creator/actions.py index 37586c23ca..2ffc8ca19f 100644 --- a/awx/lib/site-packages/south/creator/actions.py +++ b/awx/lib/site-packages/south/creator/actions.py @@ -137,12 +137,14 @@ class _NullIssuesField(object): A field that might need to ask a question about rogue NULL values. """ - allow_third_null_option = False + issue_with_backward_migration = False irreversible = False IRREVERSIBLE_TEMPLATE = ''' # User chose to not deal with backwards NULL issues for '%(model_name)s.%(field_name)s' - raise RuntimeError("Cannot reverse this migration. '%(model_name)s.%(field_name)s' and its values cannot be restored.")''' + raise RuntimeError("Cannot reverse this migration. '%(model_name)s.%(field_name)s' and its values cannot be restored.") + + # The following code is provided here to aid in writing a correct migration''' def deal_with_not_null_no_default(self, field, field_def): # If it's a CharField or TextField that's blank, skip this step. @@ -156,17 +158,17 @@ class _NullIssuesField(object): )) print(" ? Since you are %s, you MUST specify a default" % self.null_reason) print(" ? value to use for existing rows. Would you like to:") - print(" ? 1. Quit now, and add a default to the field in models.py") + print(" ? 1. Quit now"+("." if self.issue_with_backward_migration else ", and add a default to the field in models.py" )) print(" ? 2. Specify a one-off value to use for existing columns now") - if self.allow_third_null_option: - print(" ? 3. Disable the backwards migration by raising an exception.") + if self.issue_with_backward_migration: + print(" ? 3. Disable the backwards migration by raising an exception; you can edit the migration to fix it later") while True: choice = raw_input(" ? Please select a choice: ") if choice == "1": sys.exit(1) elif choice == "2": break - elif choice == "3" and self.allow_third_null_option: + elif choice == "3" and self.issue_with_backward_migration: break else: print(" ! Invalid choice.") @@ -266,7 +268,7 @@ class DeleteField(AddField): """ null_reason = "removing this field" - allow_third_null_option = True + issue_with_backward_migration = True def console_line(self): "Returns the string to print on the console, e.g. ' + Added field foo'" @@ -283,7 +285,7 @@ class DeleteField(AddField): if not self.irreversible: return AddField.forwards_code(self) else: - return self.irreversable_code(self.field) + return self.irreversable_code(self.field) + AddField.forwards_code(self) class ChangeField(Action, _NullIssuesField): @@ -315,7 +317,7 @@ class ChangeField(Action, _NullIssuesField): self.deal_with_not_null_no_default(self.new_field, self.new_def) if not self.old_field.null and self.new_field.null and not old_default: self.null_reason = "making this field nullable" - self.allow_third_null_option = True + self.issue_with_backward_migration = True self.deal_with_not_null_no_default(self.old_field, self.old_def) def console_line(self): @@ -353,10 +355,11 @@ class ChangeField(Action, _NullIssuesField): return self._code(self.old_field, self.new_field, self.new_def) def backwards_code(self): + change_code = self._code(self.new_field, self.old_field, self.old_def) if not self.irreversible: - return self._code(self.new_field, self.old_field, self.old_def) + return change_code else: - return self.irreversable_code(self.old_field) + return self.irreversable_code(self.old_field) + change_code class AddUnique(Action): diff --git a/awx/lib/site-packages/south/creator/changes.py b/awx/lib/site-packages/south/creator/changes.py index 9a11d9eb3b..6cdbd19de0 100644 --- a/awx/lib/site-packages/south/creator/changes.py +++ b/awx/lib/site-packages/south/creator/changes.py @@ -309,21 +309,21 @@ class AutoChanges(BaseChanges): old_together = [old_together] if new_together and isinstance(new_together[0], string_types): new_together = [new_together] - old_together = list(map(set, old_together)) - new_together = list(map(set, new_together)) + old_together = frozenset(tuple(o) for o in old_together) + new_together = frozenset(tuple(n) for n in new_together) # See if any appeared or disappeared - for item in old_together: - if item not in new_together: - yield (del_operation, { - "model": self.old_orm[key], - "fields": [self.old_orm[key + ":" + x] for x in item], - }) - for item in new_together: - if item not in old_together: - yield (add_operation, { - "model": self.current_model_from_key(key), - "fields": [self.current_field_from_key(key, x) for x in item], - }) + disappeared = old_together.difference(new_together) + appeared = new_together.difference(old_together) + for item in disappeared: + yield (del_operation, { + "model": self.old_orm[key], + "fields": [self.old_orm[key + ":" + x] for x in item], + }) + for item in appeared: + yield (add_operation, { + "model": self.current_model_from_key(key), + "fields": [self.current_field_from_key(key, x) for x in item], + }) @classmethod def is_triple(cls, triple): diff --git a/awx/lib/site-packages/south/db/__init__.py b/awx/lib/site-packages/south/db/__init__.py index 9927c27f0c..b9b7168e62 100644 --- a/awx/lib/site-packages/south/db/__init__.py +++ b/awx/lib/site-packages/south/db/__init__.py @@ -12,7 +12,8 @@ engine_modules = { 'django.db.backends.mysql': 'mysql', 'mysql_oursql.standard': 'mysql', 'django.db.backends.oracle': 'oracle', - 'sql_server.pyodbc': 'sql_server.pyodbc', #django-pyodbc + 'sql_server.pyodbc': 'sql_server.pyodbc', #django-pyodbc-azure + 'django_pyodbc': 'sql_server.pyodbc', #django-pyodbc 'sqlserver_ado': 'sql_server.pyodbc', #django-mssql 'firebird': 'firebird', #django-firebird 'django.contrib.gis.db.backends.postgis': 'postgresql_psycopg2', diff --git a/awx/lib/site-packages/south/db/firebird.py b/awx/lib/site-packages/south/db/firebird.py index c55a82517f..6216a52cd5 100644 --- a/awx/lib/site-packages/south/db/firebird.py +++ b/awx/lib/site-packages/south/db/firebird.py @@ -71,11 +71,11 @@ class DatabaseOperations(generic.DatabaseOperations): def _alter_set_defaults(self, field, name, params, sqls): "Subcommand of alter_column that sets default values (overrideable)" - # Next, set any default - if not field.null and field.has_default(): - default = field.get_default() - sqls.append(('ALTER COLUMN %s SET DEFAULT %%s ' % (self.quote_name(name),), [default])) - elif self._column_has_default(params): + # Historically, we used to set defaults here. + # But since South 0.8, we don't ever set defaults on alter-column -- we only + # use database-level defaults as scaffolding when adding columns. + # However, we still sometimes need to remove defaults in alter-column. + if self._column_has_default(params): sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), [])) diff --git a/awx/lib/site-packages/south/db/generic.py b/awx/lib/site-packages/south/db/generic.py index 1a26d955b5..5c1935444d 100644 --- a/awx/lib/site-packages/south/db/generic.py +++ b/awx/lib/site-packages/south/db/generic.py @@ -444,12 +444,11 @@ class DatabaseOperations(object): def _alter_set_defaults(self, field, name, params, sqls): "Subcommand of alter_column that sets default values (overrideable)" - # Next, set any default - if not field.null and field.has_default(): - default = field.get_db_prep_save(field.get_default(), connection=self._get_connection()) - sqls.append(('ALTER COLUMN %s SET DEFAULT %%s ' % (self.quote_name(name),), [default])) - else: - sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), [])) + # Historically, we used to set defaults here. + # But since South 0.8, we don't ever set defaults on alter-column -- we only + # use database-level defaults as scaffolding when adding columns. + # However, we still sometimes need to remove defaults in alter-column. + sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), [])) def _update_nulls_to_default(self, params, field): "Subcommand of alter_column that updates nulls to default value (overrideable)" @@ -835,8 +834,13 @@ class DatabaseOperations(object): # If there is just one column in the index, use a default algorithm from Django if len(column_names) == 1 and not suffix: + try: + _hash = self._digest([column_names[0]]) + except TypeError: + # Django < 1.5 backward compatibility. + _hash = self._digest(column_names[0]) return self.shorten_name( - '%s_%s' % (table_name, self._digest(column_names[0])) + '%s_%s' % (table_name, _hash), ) # Else generate the name for the index by South diff --git a/awx/lib/site-packages/south/db/oracle.py b/awx/lib/site-packages/south/db/oracle.py index 6e002945ff..cb4148b492 100644 --- a/awx/lib/site-packages/south/db/oracle.py +++ b/awx/lib/site-packages/south/db/oracle.py @@ -86,9 +86,12 @@ class DatabaseOperations(generic.DatabaseOperations): for field_name, field in fields:
+ field = self._field_sanity(field)
+ # avoid default values in CREATE TABLE statements (#925)
field._suppress_default = True
+ col = self.column_sql(table_name, field_name, field)
if not col:
continue
@@ -159,12 +162,12 @@ END; 'nullity': 'NOT NULL',
'default': 'NULL'
}
- if field.null:
+ if field.null: params['nullity'] = 'NULL'
sql_templates = [
- (self.alter_string_set_type, params),
- (self.alter_string_set_default, params),
+ (self.alter_string_set_type, params, []),
+ (self.alter_string_set_default, params, []),
]
if not field.null and field.has_default():
# Use default for rows that had nulls. To support the case where
@@ -177,8 +180,8 @@ END; p.update(kw)
return p
sql_templates[:0] = [
- (self.alter_string_set_type, change_params(nullity='NULL')),
- (self.alter_string_update_nulls_to_default, change_params(default=self._default_value_workaround(field.get_default()))),
+ (self.alter_string_set_type, change_params(nullity='NULL'),[]),
+ (self.alter_string_update_nulls_to_default, change_params(default="%s"), [field.get_default()]),
]
@@ -191,9 +194,9 @@ END; 'constraint': self.quote_name(constraint),
})
- for sql_template, params in sql_templates:
+ for sql_template, params, args in sql_templates:
try:
- self.execute(sql_template % params, print_all_errors=False)
+ self.execute(sql_template % params, args, print_all_errors=False)
except DatabaseError as exc:
description = str(exc)
# Oracle complains if a column is already NULL/NOT NULL
@@ -250,6 +253,7 @@ END; @generic.invalidate_table_constraints
def add_column(self, table_name, name, field, keep_default=False):
+ field = self._field_sanity(field)
sql = self.column_sql(table_name, name, field)
sql = self.adj_column_sql(sql)
@@ -288,7 +292,11 @@ END; """
if isinstance(field, models.BooleanField) and field.has_default():
field.default = int(field.to_python(field.get_default()))
+ # On Oracle, empty strings are null + if isinstance(field, (models.CharField, models.TextField)): + field.null = field.empty_strings_allowed
return field
+ def _default_value_workaround(self, value):
from datetime import date,time,datetime
diff --git a/awx/lib/site-packages/south/db/sql_server/pyodbc.py b/awx/lib/site-packages/south/db/sql_server/pyodbc.py index 1b200ad13b..b725ec0da6 100644 --- a/awx/lib/site-packages/south/db/sql_server/pyodbc.py +++ b/awx/lib/site-packages/south/db/sql_server/pyodbc.py @@ -242,19 +242,15 @@ class DatabaseOperations(generic.DatabaseOperations): def _alter_set_defaults(self, field, name, params, sqls): "Subcommand of alter_column that sets default values (overrideable)" - # First drop the current default if one exists + # Historically, we used to set defaults here. + # But since South 0.8, we don't ever set defaults on alter-column -- we only + # use database-level defaults as scaffolding when adding columns. + # However, we still sometimes need to remove defaults in alter-column. table_name = self.quote_name(params['table_name']) drop_default = self.drop_column_default_sql(table_name, name) if drop_default: sqls.append((drop_default, [])) - # Next, set any default - - if field.has_default(): - default = field.get_default() - literal = self._value_to_unquoted_literal(field, default) - sqls.append(('ADD DEFAULT %s for %s' % (self._quote_string(literal), self.quote_name(name),), [])) - def _value_to_unquoted_literal(self, field, value): # Start with the field's own translation conn = self._get_connection() @@ -432,6 +428,7 @@ class DatabaseOperations(generic.DatabaseOperations): INNER JOIN sys.schemas s ON t.schema_id = s.schema_id INNER JOIN sys.indexes i ON i.object_id = t.object_id INNER JOIN sys.index_columns ic ON ic.object_id = t.object_id + AND ic.index_id = i.index_id INNER JOIN sys.columns c ON c.object_id = t.object_id AND ic.column_id = c.column_id WHERE i.is_unique=0 AND i.is_primary_key=0 AND i.is_unique_constraint=0 diff --git a/awx/lib/site-packages/south/db/sqlite3.py b/awx/lib/site-packages/south/db/sqlite3.py index db45511456..ecdb2ce6d2 100644 --- a/awx/lib/site-packages/south/db/sqlite3.py +++ b/awx/lib/site-packages/south/db/sqlite3.py @@ -31,7 +31,7 @@ class DatabaseOperations(generic.DatabaseOperations): field_default = None
if not getattr(field, '_suppress_default', False):
default = field.get_default()
- if default is not None and default!='':
+ if default is not None:
field_default = "'%s'" % field.get_db_prep_save(default, connection=self._get_connection())
field._suppress_default = True
self._remake_table(table_name, added={
@@ -136,7 +136,7 @@ class DatabaseOperations(generic.DatabaseOperations): continue
src_fields_new.append(self.quote_name(field))
for field, (_,default) in added.items():
- if default is not None and default!='':
+ if default is not None:
field = self.quote_name(field)
src_fields_new.append("%s as %s" % (default, field))
dst_fields_new.append(field)
@@ -263,10 +263,3 @@ class DatabaseOperations(generic.DatabaseOperations): # No cascades on deletes
def delete_table(self, table_name, cascade=True):
generic.DatabaseOperations.delete_table(self, table_name, False)
-
- def _default_value_workaround(self, default):
- if default == True:
- default = 1
- elif default == False:
- default = 0
- return default
diff --git a/awx/lib/site-packages/south/hacks/django_1_0.py b/awx/lib/site-packages/south/hacks/django_1_0.py index 00d0a8bab3..e4a60c66a9 100644 --- a/awx/lib/site-packages/south/hacks/django_1_0.py +++ b/awx/lib/site-packages/south/hacks/django_1_0.py @@ -2,6 +2,7 @@ Hacks for the Django 1.0/1.0.2 releases. """ +import django from django.conf import settings from django.db.backends.creation import BaseDatabaseCreation from django.db.models.loading import cache @@ -51,7 +52,7 @@ class Hacks: Used to repopulate AppCache after fiddling with INSTALLED_APPS. """ cache.loaded = False - cache.handled = {} + cache.handled = set() if django.VERSION >= (1, 6) else {} cache.postponed = [] cache.app_store = SortedDict() cache.app_models = SortedDict() diff --git a/awx/lib/site-packages/south/management/commands/datamigration.py b/awx/lib/site-packages/south/management/commands/datamigration.py index 08c1d0891c..9254e56c15 100644 --- a/awx/lib/site-packages/south/management/commands/datamigration.py +++ b/awx/lib/site-packages/south/management/commands/datamigration.py @@ -34,6 +34,8 @@ class Command(BaseCommand): usage_str = "Usage: ./manage.py datamigration appname migrationname [--stdout] [--freeze appname]" def handle(self, app=None, name="", freeze_list=None, stdout=False, verbosity=1, **options): + + verbosity = int(verbosity) # Any supposed lists that are None become empty lists freeze_list = freeze_list or [] @@ -46,13 +48,19 @@ class Command(BaseCommand): if re.search('[^_\w]', name) and name != "-": self.error("Migration names should contain only alphanumeric characters and underscores.") - # if not name, there's an error + # If not name, there's an error if not name: - self.error("You must provide a name for this migration\n" + self.usage_str) + self.error("You must provide a name for this migration.\n" + self.usage_str) if not app: self.error("You must provide an app to create a migration for.\n" + self.usage_str) - + + # Ensure that verbosity is not a string (Python 3) + try: + verbosity = int(verbosity) + except ValueError: + self.error("Verbosity must be an number.\n" + self.usage_str) + # Get the Migrations for this app (creating the migrations dir if needed) migrations = Migrations(app, force_creation=True, verbose_creation=verbosity > 0) @@ -63,7 +71,7 @@ class Command(BaseCommand): apps_to_freeze = self.calc_frozen_apps(migrations, freeze_list) # So, what's in this file, then? - file_contents = MIGRATION_TEMPLATE % { + file_contents = self.get_migration_template() % { "frozen_models": freezer.freeze_apps_to_string(apps_to_freeze), "complete_apps": apps_to_freeze and "complete_apps = [%s]" % (", ".join(map(repr, apps_to_freeze))) or "" } @@ -103,6 +111,9 @@ class Command(BaseCommand): print(message, file=sys.stderr) sys.exit(code) + def get_migration_template(self): + return MIGRATION_TEMPLATE + MIGRATION_TEMPLATE = """# -*- coding: utf-8 -*- import datetime diff --git a/awx/lib/site-packages/south/management/commands/schemamigration.py b/awx/lib/site-packages/south/management/commands/schemamigration.py index e29fc620b6..514f4a79d9 100644 --- a/awx/lib/site-packages/south/management/commands/schemamigration.py +++ b/awx/lib/site-packages/south/management/commands/schemamigration.py @@ -168,7 +168,7 @@ class Command(DataCommand): apps_to_freeze = self.calc_frozen_apps(migrations, freeze_list) # So, what's in this file, then? - file_contents = MIGRATION_TEMPLATE % { + file_contents = self.get_migration_template() % { "forwards": "\n".join(forwards_actions or [" pass"]), "backwards": "\n".join(backwards_actions or [" pass"]), "frozen_models": freezer.freeze_apps_to_string(apps_to_freeze), @@ -205,6 +205,9 @@ class Command(DataCommand): else: print("%s %s. You can now apply this migration with: ./manage.py migrate %s" % (verb, new_filename, app), file=sys.stderr) + def get_migration_template(self): + return MIGRATION_TEMPLATE + MIGRATION_TEMPLATE = """# -*- coding: utf-8 -*- import datetime diff --git a/awx/lib/site-packages/south/management/commands/syncdb.py b/awx/lib/site-packages/south/management/commands/syncdb.py index 702085b194..17fc22cbfc 100644 --- a/awx/lib/site-packages/south/management/commands/syncdb.py +++ b/awx/lib/site-packages/south/management/commands/syncdb.py @@ -98,6 +98,8 @@ class Command(NoArgsCommand): if options.get('migrate', True): if verbosity: print("Migrating...") + # convert from store_true to store_false + options['no_initial_data'] = not options.get('load_initial_data', True) management.call_command('migrate', **options) # Be obvious about what we did diff --git a/awx/lib/site-packages/south/migration/migrators.py b/awx/lib/site-packages/south/migration/migrators.py index 1be895dcf3..98b3489165 100644 --- a/awx/lib/site-packages/south/migration/migrators.py +++ b/awx/lib/site-packages/south/migration/migrators.py @@ -5,10 +5,6 @@ import datetime import inspect import sys import traceback -try: - from cStringIO import StringIO # python 2 -except ImportError: - from io import StringIO # python 3 from django.core.management import call_command from django.core.management.commands import loaddata @@ -19,6 +15,7 @@ from south import exceptions from south.db import DEFAULT_DB_ALIAS from south.models import MigrationHistory from south.signals import ran_migration +from south.utils.py3 import StringIO class Migrator(object): diff --git a/awx/lib/site-packages/south/orm.py b/awx/lib/site-packages/south/orm.py index 1e5d56ed79..8d46ee7194 100644 --- a/awx/lib/site-packages/south/orm.py +++ b/awx/lib/site-packages/south/orm.py @@ -181,12 +181,16 @@ class _FakeORM(object): "Evaluates the given code in the context of the migration file." # Drag in the migration module's locals (hopefully including models.py) - fake_locals = dict(inspect.getmodule(self.cls).__dict__) - - # Remove all models from that (i.e. from modern models.py), to stop pollution - for key, value in fake_locals.items(): - if isinstance(value, type) and issubclass(value, models.Model) and hasattr(value, "_meta"): - del fake_locals[key] + # excluding all models from that (i.e. from modern models.py), to stop pollution + fake_locals = dict( + (key, value) + for key, value in inspect.getmodule(self.cls).__dict__.items() + if not ( + isinstance(value, type) + and issubclass(value, models.Model) + and hasattr(value, "_meta") + ) + ) # We add our models into the locals for the eval fake_locals.update(dict([ diff --git a/awx/lib/site-packages/south/tests/db.py b/awx/lib/site-packages/south/tests/db.py index 90f62fc0be..353677ebed 100644 --- a/awx/lib/site-packages/south/tests/db.py +++ b/awx/lib/site-packages/south/tests/db.py @@ -360,7 +360,57 @@ class TestOperations(unittest.TestCase): db.execute("INSERT INTO test_altercd (eggs) values (12)") null = db.execute("SELECT spam FROM test_altercd")[0][0] self.assertFalse(null, "Default for char field was installed into database") + + # Change again to a column with default and not null + db.alter_column("test_altercd", "spam", models.CharField(max_length=30, default="loof", null=False)) + # Assert the default is not in the database + if 'oracle' in db.backend_name: + # Oracle special treatment -- nulls are always allowed in char columns, so + # inserting doesn't raise an integrity error; so we check again as above + db.execute("DELETE FROM test_altercd") + db.execute("INSERT INTO test_altercd (eggs) values (12)") + null = db.execute("SELECT spam FROM test_altercd")[0][0] + self.assertFalse(null, "Default for char field was installed into database") + else: + # For other backends, insert should now just fail + self.assertRaises(IntegrityError, + db.execute, "INSERT INTO test_altercd (eggs) values (12)") + + @skipIf('oracle' in db.backend_name, "Oracle does not differentiate empty trings from null") + def test_default_empty_string(self): + """ + Test altering column defaults with char fields + """ + db.create_table("test_cd_empty", [ + ('spam', models.CharField(max_length=30, default='')), + ('eggs', models.CharField(max_length=30)), + ]) + # Create a record + db.execute("INSERT INTO test_cd_empty (spam, eggs) values ('1','2')") + # Add a column + db.add_column("test_cd_empty", "ham", models.CharField(max_length=30, default='')) + empty = db.execute("SELECT ham FROM test_cd_empty")[0][0] + self.assertEquals(empty, "", "Empty Default for char field isn't empty string") + + @skipUnless('oracle' in db.backend_name, "Oracle does not differentiate empty trings from null") + def test_oracle_strings_null(self): + """ + Test that under Oracle, CherFields are created as null even when specified not-null, + because otherwise they would not be able to hold empty strings (which Oracle equates + with nulls). + Verify fix of #1269. + """ + db.create_table("test_ora_char_nulls", [ + ('spam', models.CharField(max_length=30, null=True)), + ('eggs', models.CharField(max_length=30)), + ]) + db.add_column("test_ora_char_nulls", "ham", models.CharField(max_length=30)) + db.alter_column("test_ora_char_nulls", "spam", models.CharField(max_length=30, null=False)) + # So, by the look of it, we should now have three not-null columns + db.execute("INSERT INTO test_ora_char_nulls VALUES (NULL, NULL, NULL)") + + def test_mysql_defaults(self): """ Test MySQL default handling for BLOB and TEXT. @@ -568,8 +618,8 @@ class TestOperations(unittest.TestCase): db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (1, 42)") except: self.fail("Looks like multi-field unique constraint applied to only one field.") - db.start_transaction() db.rollback_transaction() + db.start_transaction() try: db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (0, 43)") except: @@ -581,17 +631,17 @@ class TestOperations(unittest.TestCase): except: pass else: - self.fail("Could insert the same integer twice into a unique field.") + self.fail("Could insert the same pair twice into unique-together fields.") db.rollback_transaction() # Altering one column should not drop or modify multi-column constraint - db.alter_column("test_alter_unique2", "eggs", models.CharField(max_length=10)) + db.alter_column("test_alter_unique2", "eggs", models.PositiveIntegerField()) db.start_transaction() try: db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (1, 42)") except: self.fail("Altering one column broken multi-column unique constraint.") - db.start_transaction() db.rollback_transaction() + db.start_transaction() try: db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (0, 43)") except: @@ -603,7 +653,7 @@ class TestOperations(unittest.TestCase): except: pass else: - self.fail("Could insert the same integer twice into a unique field after alter_column with unique=False.") + self.fail("Could insert the same pair twice into unique-together fields after alter_column with unique=False.") db.rollback_transaction() db.delete_table("test_alter_unique2") db.start_transaction() @@ -811,10 +861,39 @@ class TestOperations(unittest.TestCase): db.add_column("test_fk", 'foreik', models.ForeignKey(User, null=True)) db.execute_deferred_sql() - # Make the FK null + # Make the FK not null db.alter_column("test_fk", "foreik_id", models.ForeignKey(User)) db.execute_deferred_sql() + def test_make_foreign_key_null(self): + # Table for FK to target + User = db.mock_model(model_name='User', db_table='auth_user', db_tablespace='', pk_field_name='id', pk_field_type=models.AutoField, pk_field_args=[], pk_field_kwargs={}) + # Table with no foreign key + db.create_table("test_make_fk_null", [ + ('eggs', models.IntegerField()), + ('foreik', models.ForeignKey(User)) + ]) + db.execute_deferred_sql() + + # Make the FK null + db.alter_column("test_make_fk_null", "foreik_id", models.ForeignKey(User, null=True)) + db.execute_deferred_sql() + + def test_alter_double_indexed_column(self): + # Table for FK to target + User = db.mock_model(model_name='User', db_table='auth_user', db_tablespace='', pk_field_name='id', pk_field_type=models.AutoField, pk_field_args=[], pk_field_kwargs={}) + # Table with no foreign key + db.create_table("test_2indexed", [ + ('eggs', models.IntegerField()), + ('foreik', models.ForeignKey(User)) + ]) + db.create_unique("test_2indexed", ["eggs", "foreik_id"]) + db.execute_deferred_sql() + + # Make the FK null + db.alter_column("test_2indexed", "foreik_id", models.ForeignKey(User, null=True)) + db.execute_deferred_sql() + class TestCacheGeneric(unittest.TestCase): base_ops_cls = generic.DatabaseOperations def setUp(self): diff --git a/awx/lib/site-packages/south/utils/__init__.py b/awx/lib/site-packages/south/utils/__init__.py index c3c5191633..8d7297ea5d 100644 --- a/awx/lib/site-packages/south/utils/__init__.py +++ b/awx/lib/site-packages/south/utils/__init__.py @@ -5,13 +5,13 @@ Generally helpful utility functions. def _ask_for_it_by_name(name): "Returns an object referenced by absolute path." - bits = name.split(".") + bits = str(name).split(".") ## what if there is no absolute reference? - if len(bits)>1: + if len(bits) > 1: modulename = ".".join(bits[:-1]) else: - modulename=bits[0] + modulename = bits[0] module = __import__(modulename, {}, {}, bits[-1]) diff --git a/awx/lib/site-packages/south/utils/py3.py b/awx/lib/site-packages/south/utils/py3.py index 9c5baaded0..732e9043a8 100644 --- a/awx/lib/site-packages/south/utils/py3.py +++ b/awx/lib/site-packages/south/utils/py3.py @@ -11,11 +11,18 @@ if PY3: text_type = str raw_input = input + import io + StringIO = io.StringIO + else: string_types = basestring, text_type = unicode raw_input = raw_input + import cStringIO + StringIO = cStringIO.StringIO + + def with_metaclass(meta, base=object): """Create a base class with a metaclass.""" return meta("NewBase", (base,), {}) diff --git a/awx/lib/site-packages/taggit/__init__.py b/awx/lib/site-packages/taggit/__init__.py index 4c4993466d..ca04bbe69d 100644 --- a/awx/lib/site-packages/taggit/__init__.py +++ b/awx/lib/site-packages/taggit/__init__.py @@ -1 +1 @@ -VERSION = (0, 10, 0, 'alpha', 1) +VERSION = (0, 10, 0) diff --git a/awx/lib/site-packages/taggit/admin.py b/awx/lib/site-packages/taggit/admin.py index 6c012d6fe3..0498c9ddf4 100644 --- a/awx/lib/site-packages/taggit/admin.py +++ b/awx/lib/site-packages/taggit/admin.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + from django.contrib import admin from taggit.models import Tag, TaggedItem diff --git a/awx/lib/site-packages/taggit/forms.py b/awx/lib/site-packages/taggit/forms.py index e0198bd933..cb372f22a0 100644 --- a/awx/lib/site-packages/taggit/forms.py +++ b/awx/lib/site-packages/taggit/forms.py @@ -1,12 +1,15 @@ +from __future__ import unicode_literals + from django import forms from django.utils.translation import ugettext as _ +from django.utils import six from taggit.utils import parse_tags, edit_string_for_tags class TagWidget(forms.TextInput): def render(self, name, value, attrs=None): - if value is not None and not isinstance(value, basestring): + if value is not None and not isinstance(value, six.string_types): value = edit_string_for_tags([o.tag for o in value.select_related("tag")]) return super(TagWidget, self).render(name, value, attrs) diff --git a/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mo b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mo Binary files differnew file mode 100644 index 0000000000..9ce13fb2c2 --- /dev/null +++ b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mo diff --git a/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po new file mode 100644 index 0000000000..13262e1c78 --- /dev/null +++ b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po @@ -0,0 +1,64 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2013-08-01 16:52+0200\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" +"Language-Team: LANGUAGE <LL@li.org>\n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n==1) ? 0 : (n>=2 && n<=4) ? 1 : 2;\n" + +#: forms.py:24 +msgid "Please provide a comma-separated list of tags." +msgstr "Vložte čárkami oddělený seznam tagů" + +#: managers.py:59 models.py:59 +msgid "Tags" +msgstr "Tagy" + +#: managers.py:60 +msgid "A comma-separated list of tags." +msgstr "Čárkami oddělený seznam tagů" + +#: models.py:15 +msgid "Name" +msgstr "Jméno" + +#: models.py:16 +msgid "Slug" +msgstr "Slug" + +#: models.py:58 +msgid "Tag" +msgstr "Tag" + +#: models.py:65 +#, python-format +msgid "%(object)s tagged with %(tag)s" +msgstr "%(object)s označen tagem %(tag)s" + +#: models.py:112 +msgid "Object id" +msgstr "ID objektu" + +#: models.py:115 +msgid "Content type" +msgstr "Typ obsahu" + +#: models.py:158 +msgid "Tagged Item" +msgstr "Tagem označená položka" + +#: models.py:159 +msgid "Tagged Items" +msgstr "Tagy označené položky" diff --git a/awx/lib/site-packages/taggit/managers.py b/awx/lib/site-packages/taggit/managers.py index ca1700fd96..56fec6064c 100644 --- a/awx/lib/site-packages/taggit/managers.py +++ b/awx/lib/site-packages/taggit/managers.py @@ -1,57 +1,78 @@ +from __future__ import unicode_literals + +from django import VERSION from django.contrib.contenttypes.generic import GenericRelation from django.contrib.contenttypes.models import ContentType from django.db import models +from django.db.models.fields import Field from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation from django.db.models.related import RelatedObject from django.utils.text import capfirst from django.utils.translation import ugettext_lazy as _ +from django.utils import six + +try: + from django.db.models.related import PathInfo +except ImportError: + pass # PathInfo is not used on Django < 1.6 from taggit.forms import TagField from taggit.models import TaggedItem, GenericTaggedItemBase from taggit.utils import require_instance_manager -try: - all -except NameError: - # 2.4 compat - try: - from django.utils.itercompat import all - except ImportError: - # 1.1.X compat - def all(iterable): - for item in iterable: - if not item: - return False - return True +def _model_name(model): + if VERSION < (1, 7): + return model._meta.module_name + else: + return model._meta.model_name class TaggableRel(ManyToManyRel): - def __init__(self): + def __init__(self, field): self.related_name = None self.limit_choices_to = {} self.symmetrical = True self.multiple = True self.through = None + self.field = field + + def get_joining_columns(self): + return self.field.get_reverse_joining_columns() + + def get_extra_restriction(self, where_class, alias, related_alias): + return self.field.get_extra_restriction(where_class, related_alias, alias) + +class ExtraJoinRestriction(object): + """ + An extra restriction used for contenttype restriction in joins. + """ + def __init__(self, alias, col, content_types): + self.alias = alias + self.col = col + self.content_types = content_types -class TaggableManager(RelatedField): + def as_sql(self, qn, connection): + if len(self.content_types) == 1: + extra_where = "%s.%s = %%s" % (qn(self.alias), qn(self.col)) + params = [self.content_types[0]] + else: + extra_where = "%s.%s IN (%s)" % (qn(self.alias), qn(self.col), + ','.join(['%s'] * len(self.content_types))) + params = self.content_types + return extra_where, params + + def relabel_aliases(self, change_map): + self.alias = change_map.get(self.alias, self.alias) + + +class TaggableManager(RelatedField, Field): def __init__(self, verbose_name=_("Tags"), help_text=_("A comma-separated list of tags."), through=None, blank=False): + Field.__init__(self, verbose_name=verbose_name, help_text=help_text, blank=blank) self.through = through or TaggedItem - self.rel = TaggableRel() - self.verbose_name = verbose_name - self.help_text = help_text - self.blank = blank - self.editable = True - self.unique = False - self.creates_table = False - self.db_column = None - self.choices = None - self.serialize = False - self.null = True - self.creation_counter = models.Field.creation_counter - models.Field.creation_counter += 1 + self.rel = TaggableRel(self) def __get__(self, instance, model): if instance is not None and instance.pk is None: @@ -63,12 +84,15 @@ class TaggableManager(RelatedField): return manager def contribute_to_class(self, cls, name): - self.name = self.column = name + if VERSION < (1, 7): + self.name = self.column = name + else: + self.set_attributes_from_name(name) self.model = cls cls._meta.add_field(self) setattr(cls, name, self) if not cls._meta.abstract: - if isinstance(self.through, basestring): + if isinstance(self.through, six.string_types): def resolve_related_class(field, model, cls): self.through = model self.post_through_setup(cls) @@ -78,7 +102,16 @@ class TaggableManager(RelatedField): else: self.post_through_setup(cls) + def __lt__(self, other): + """ + Required contribute_to_class as Django uses bisect + for ordered class contribution and bisect requires + a orderable type in py3. + """ + return False + def post_through_setup(self, cls): + self.related = RelatedObject(cls, self.model, self) self.use_gfk = ( self.through is None or issubclass(self.through, GenericTaggedItemBase) ) @@ -106,11 +139,14 @@ class TaggableManager(RelatedField): return self.through.objects.none() def related_query_name(self): - return self.model._meta.module_name + return _model_name(self.model) def m2m_reverse_name(self): return self.through._meta.get_field_by_name("tag")[0].column + def m2m_reverse_field_name(self): + return self.through._meta.get_field_by_name("tag")[0].name + def m2m_target_field_name(self): return self.model._meta.pk.name @@ -128,17 +164,148 @@ class TaggableManager(RelatedField): def m2m_db_table(self): return self.through._meta.db_table + def bulk_related_objects(self, new_objs, using): + return [] + def extra_filters(self, pieces, pos, negate): if negate or not self.use_gfk: return [] prefix = "__".join(["tagged_items"] + pieces[:pos-2]) - cts = map(ContentType.objects.get_for_model, _get_subclasses(self.model)) + get = ContentType.objects.get_for_model + cts = [get(obj) for obj in _get_subclasses(self.model)] if len(cts) == 1: return [("%s__content_type" % prefix, cts[0])] return [("%s__content_type__in" % prefix, cts)] - def bulk_related_objects(self, new_objs, using): - return [] + def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias): + model_name = _model_name(self.through) + if rhs_alias == '%s_%s' % (self.through._meta.app_label, model_name): + alias_to_join = rhs_alias + else: + alias_to_join = lhs_alias + extra_col = self.through._meta.get_field_by_name('content_type')[0].column + content_type_ids = [ContentType.objects.get_for_model(subclass).pk for subclass in _get_subclasses(self.model)] + if len(content_type_ids) == 1: + content_type_id = content_type_ids[0] + extra_where = " AND %s.%s = %%s" % (qn(alias_to_join), qn(extra_col)) + params = [content_type_id] + else: + extra_where = " AND %s.%s IN (%s)" % (qn(alias_to_join), qn(extra_col), ','.join(['%s']*len(content_type_ids))) + params = content_type_ids + return extra_where, params + + def _get_mm_case_path_info(self, direct=False): + pathinfos = [] + linkfield1 = self.through._meta.get_field_by_name('content_object')[0] + linkfield2 = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0] + if direct: + join1infos, _, _, _ = linkfield1.get_reverse_path_info() + join2infos, opts, target, final = linkfield2.get_path_info() + else: + join1infos, _, _, _ = linkfield2.get_reverse_path_info() + join2infos, opts, target, final = linkfield1.get_path_info() + pathinfos.extend(join1infos) + pathinfos.extend(join2infos) + return pathinfos, opts, target, final + + def _get_gfk_case_path_info(self, direct=False): + pathinfos = [] + from_field = self.model._meta.pk + opts = self.through._meta + object_id_field = opts.get_field_by_name('object_id')[0] + linkfield = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0] + if direct: + join1infos = [PathInfo(from_field, object_id_field, self.model._meta, opts, self, True, False)] + join2infos, opts, target, final = linkfield.get_path_info() + else: + join1infos, _, _, _ = linkfield.get_reverse_path_info() + join2infos = [PathInfo(object_id_field, from_field, opts, self.model._meta, self, True, False)] + target = from_field + final = self + opts = self.model._meta + + pathinfos.extend(join1infos) + pathinfos.extend(join2infos) + return pathinfos, opts, target, final + + def get_path_info(self): + if self.use_gfk: + return self._get_gfk_case_path_info(direct=True) + else: + return self._get_mm_case_path_info(direct=True) + + def get_reverse_path_info(self): + if self.use_gfk: + return self._get_gfk_case_path_info(direct=False) + else: + return self._get_mm_case_path_info(direct=False) + + # This and all the methods till the end of class are only used in django >= 1.6 + def _get_mm_case_path_info(self, direct=False): + pathinfos = [] + linkfield1 = self.through._meta.get_field_by_name('content_object')[0] + linkfield2 = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0] + if direct: + join1infos = linkfield1.get_reverse_path_info() + join2infos = linkfield2.get_path_info() + else: + join1infos = linkfield2.get_reverse_path_info() + join2infos = linkfield1.get_path_info() + pathinfos.extend(join1infos) + pathinfos.extend(join2infos) + return pathinfos + + def _get_gfk_case_path_info(self, direct=False): + pathinfos = [] + from_field = self.model._meta.pk + opts = self.through._meta + object_id_field = opts.get_field_by_name('object_id')[0] + linkfield = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0] + if direct: + join1infos = [PathInfo(self.model._meta, opts, [from_field], self.rel, True, False)] + join2infos = linkfield.get_path_info() + else: + join1infos = linkfield.get_reverse_path_info() + join2infos = [PathInfo(opts, self.model._meta, [object_id_field], self, True, False)] + pathinfos.extend(join1infos) + pathinfos.extend(join2infos) + return pathinfos + + def get_path_info(self): + if self.use_gfk: + return self._get_gfk_case_path_info(direct=True) + else: + return self._get_mm_case_path_info(direct=True) + + def get_reverse_path_info(self): + if self.use_gfk: + return self._get_gfk_case_path_info(direct=False) + else: + return self._get_mm_case_path_info(direct=False) + + def get_joining_columns(self, reverse_join=False): + if reverse_join: + return (("id", "object_id"),) + else: + return (("object_id", "id"),) + + def get_extra_restriction(self, where_class, alias, related_alias): + extra_col = self.through._meta.get_field_by_name('content_type')[0].column + content_type_ids = [ContentType.objects.get_for_model(subclass).pk + for subclass in _get_subclasses(self.model)] + return ExtraJoinRestriction(related_alias, extra_col, content_type_ids) + + def get_reverse_joining_columns(self): + return self.get_joining_columns(reverse_join=True) + + @property + def related_fields(self): + return [(self.through._meta.get_field_by_name('object_id')[0], + self.model._meta.pk)] + + @property + def foreign_related_fields(self): + return [self.related_fields[0][1]] class _TaggableManager(models.Manager): @@ -150,6 +317,9 @@ class _TaggableManager(models.Manager): def get_query_set(self): return self.through.tags_for(self.model, self.instance) + # Django 1.6 renamed this + get_queryset = get_query_set + def _lookup_kwargs(self): return self.through.lookup_kwargs(self.instance) @@ -175,6 +345,14 @@ class _TaggableManager(models.Manager): self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs()) @require_instance_manager + def names(self): + return self.get_query_set().values_list('name', flat=True) + + @require_instance_manager + def slugs(self): + return self.get_query_set().values_list('slug', flat=True) + + @require_instance_manager def set(self, *tags): self.clear() self.add(*tags) @@ -197,7 +375,7 @@ class _TaggableManager(models.Manager): def similar_objects(self): lookup_kwargs = self._lookup_kwargs() lookup_keys = sorted(lookup_kwargs) - qs = self.through.objects.values(*lookup_kwargs.keys()) + qs = self.through.objects.values(*six.iterkeys(lookup_kwargs)) qs = qs.annotate(n=models.Count('pk')) qs = qs.exclude(**lookup_kwargs) qs = qs.filter(tag__in=self.all()) @@ -220,7 +398,7 @@ class _TaggableManager(models.Manager): preload.setdefault(result['content_type'], set()) preload[result["content_type"]].add(result["object_id"]) - for ct, obj_ids in preload.iteritems(): + for ct, obj_ids in preload.items(): ct = ContentType.objects.get_for_id(ct) for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids): items[(ct.pk, obj.pk)] = obj @@ -243,3 +421,11 @@ def _get_subclasses(model): getattr(field.field.rel, "parent_link", None)): subclasses.extend(_get_subclasses(field.model)) return subclasses + + +# `total_ordering` does not exist in Django 1.4, as such +# we special case this import to be py3k specific which +# is not supported by Django 1.4 +if six.PY3: + from django.utils.functional import total_ordering + TaggableManager = total_ordering(TaggableManager) diff --git a/awx/lib/site-packages/taggit/migrations/0001_initial.py b/awx/lib/site-packages/taggit/migrations/0001_initial.py index 666ca6199e..6808f38bca 100644 --- a/awx/lib/site-packages/taggit/migrations/0001_initial.py +++ b/awx/lib/site-packages/taggit/migrations/0001_initial.py @@ -9,52 +9,52 @@ class Migration(SchemaMigration): def forwards(self, orm): # Adding model 'Tag' - db.create_table(u'taggit_tag', ( - (u'id', self.gf('django.db.models.fields.AutoField')(primary_key=True)), + db.create_table('taggit_tag', ( + ('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)), ('name', self.gf('django.db.models.fields.CharField')(max_length=100)), ('slug', self.gf('django.db.models.fields.SlugField')(unique=True, max_length=100)), )) - db.send_create_signal(u'taggit', ['Tag']) + db.send_create_signal('taggit', ['Tag']) # Adding model 'TaggedItem' - db.create_table(u'taggit_taggeditem', ( - (u'id', self.gf('django.db.models.fields.AutoField')(primary_key=True)), - ('tag', self.gf('django.db.models.fields.related.ForeignKey')(related_name=u'taggit_taggeditem_items', to=orm['taggit.Tag'])), + db.create_table('taggit_taggeditem', ( + ('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)), + ('tag', self.gf('django.db.models.fields.related.ForeignKey')(related_name='taggit_taggeditem_items', to=orm['taggit.Tag'])), ('object_id', self.gf('django.db.models.fields.IntegerField')(db_index=True)), - ('content_type', self.gf('django.db.models.fields.related.ForeignKey')(related_name=u'taggit_taggeditem_tagged_items', to=orm['contenttypes.ContentType'])), + ('content_type', self.gf('django.db.models.fields.related.ForeignKey')(related_name='taggit_taggeditem_tagged_items', to=orm['contenttypes.ContentType'])), )) - db.send_create_signal(u'taggit', ['TaggedItem']) + db.send_create_signal('taggit', ['TaggedItem']) def backwards(self, orm): # Deleting model 'Tag' - db.delete_table(u'taggit_tag') + db.delete_table('taggit_tag') # Deleting model 'TaggedItem' - db.delete_table(u'taggit_taggeditem') + db.delete_table('taggit_taggeditem') models = { - u'contenttypes.contenttype': { + 'contenttypes.contenttype': { 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, 'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}), - u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}), 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}) }, - u'taggit.tag': { + 'taggit.tag': { 'Meta': {'object_name': 'Tag'}, - u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}), 'slug': ('django.db.models.fields.SlugField', [], {'unique': 'True', 'max_length': '100'}) }, - u'taggit.taggeditem': { + 'taggit.taggeditem': { 'Meta': {'object_name': 'TaggedItem'}, - 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_tagged_items'", 'to': u"orm['contenttypes.ContentType']"}), - u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_tagged_items'", 'to': "orm['contenttypes.ContentType']"}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'object_id': ('django.db.models.fields.IntegerField', [], {'db_index': 'True'}), - 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_items'", 'to': u"orm['taggit.Tag']"}) + 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_items'", 'to': "orm['taggit.Tag']"}) } } - complete_apps = ['taggit']
\ No newline at end of file + complete_apps = ['taggit'] diff --git a/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py b/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py index e5eb033b0a..d68ea10164 100644 --- a/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py +++ b/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py @@ -9,35 +9,35 @@ class Migration(SchemaMigration): def forwards(self, orm): # Adding unique constraint on 'Tag', fields ['name'] - db.create_unique(u'taggit_tag', ['name']) + db.create_unique('taggit_tag', ['name']) def backwards(self, orm): # Removing unique constraint on 'Tag', fields ['name'] - db.delete_unique(u'taggit_tag', ['name']) + db.delete_unique('taggit_tag', ['name']) models = { - u'contenttypes.contenttype': { + 'contenttypes.contenttype': { 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, 'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}), - u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}), 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}) }, - u'taggit.tag': { + 'taggit.tag': { 'Meta': {'object_name': 'Tag'}, - u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '100'}), 'slug': ('django.db.models.fields.SlugField', [], {'unique': 'True', 'max_length': '100'}) }, - u'taggit.taggeditem': { + 'taggit.taggeditem': { 'Meta': {'object_name': 'TaggedItem'}, - 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_tagged_items'", 'to': u"orm['contenttypes.ContentType']"}), - u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_tagged_items'", 'to': "orm['contenttypes.ContentType']"}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'object_id': ('django.db.models.fields.IntegerField', [], {'db_index': 'True'}), - 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_items'", 'to': u"orm['taggit.Tag']"}) + 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_items'", 'to': "orm['taggit.Tag']"}) } } - complete_apps = ['taggit']
\ No newline at end of file + complete_apps = ['taggit'] diff --git a/awx/lib/site-packages/taggit/models.py b/awx/lib/site-packages/taggit/models.py index 581a5b1920..0335757875 100644 --- a/awx/lib/site-packages/taggit/models.py +++ b/awx/lib/site-packages/taggit/models.py @@ -1,16 +1,21 @@ +from __future__ import unicode_literals + import django from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.generic import GenericForeignKey from django.db import models, IntegrityError, transaction +from django.db.models.query import QuerySet from django.template.defaultfilters import slugify as default_slugify from django.utils.translation import ugettext_lazy as _, ugettext +from django.utils.encoding import python_2_unicode_compatible +@python_2_unicode_compatible class TagBase(models.Model): name = models.CharField(verbose_name=_('Name'), unique=True, max_length=100) slug = models.SlugField(verbose_name=_('Slug'), unique=True, max_length=100) - def __unicode__(self): + def __str__(self): return self.name class Meta: @@ -19,17 +24,14 @@ class TagBase(models.Model): def save(self, *args, **kwargs): if not self.pk and not self.slug: self.slug = self.slugify(self.name) - if django.VERSION >= (1, 2): - from django.db import router - using = kwargs.get("using") or router.db_for_write( - type(self), instance=self) - # Make sure we write to the same db for all attempted writes, - # with a multi-master setup, theoretically we could try to - # write and rollback on different DBs - kwargs["using"] = using - trans_kwargs = {"using": using} - else: - trans_kwargs = {} + from django.db import router + using = kwargs.get("using") or router.db_for_write( + type(self), instance=self) + # Make sure we write to the same db for all attempted writes, + # with a multi-master setup, theoretically we could try to + # write and rollback on different DBs + kwargs["using"] = using + trans_kwargs = {"using": using} i = 0 while True: i += 1 @@ -57,9 +59,9 @@ class Tag(TagBase): verbose_name_plural = _("Tags") - +@python_2_unicode_compatible class ItemBase(models.Model): - def __unicode__(self): + def __str__(self): return ugettext("%(object)s tagged with %(tag)s") % { "object": self.content_object, "tag": self.tag @@ -90,10 +92,7 @@ class ItemBase(models.Model): class TaggedItemBase(ItemBase): - if django.VERSION < (1, 2): - tag = models.ForeignKey(Tag, related_name="%(class)s_items") - else: - tag = models.ForeignKey(Tag, related_name="%(app_label)s_%(class)s_items") + tag = models.ForeignKey(Tag, related_name="%(app_label)s_%(class)s_items") class Meta: abstract = True @@ -111,18 +110,11 @@ class TaggedItemBase(ItemBase): class GenericTaggedItemBase(ItemBase): object_id = models.IntegerField(verbose_name=_('Object id'), db_index=True) - if django.VERSION < (1, 2): - content_type = models.ForeignKey( - ContentType, - verbose_name=_('Content type'), - related_name="%(class)s_tagged_items" - ) - else: - content_type = models.ForeignKey( - ContentType, - verbose_name=_('Content type'), - related_name="%(app_label)s_%(class)s_tagged_items" - ) + content_type = models.ForeignKey( + ContentType, + verbose_name=_('Content type'), + related_name="%(app_label)s_%(class)s_tagged_items" + ) content_object = GenericForeignKey() class Meta: @@ -137,11 +129,18 @@ class GenericTaggedItemBase(ItemBase): @classmethod def bulk_lookup_kwargs(cls, instances): - # TODO: instances[0], can we assume there are instances. - return { - "object_id__in": [instance.pk for instance in instances], - "content_type": ContentType.objects.get_for_model(instances[0]), - } + if isinstance(instances, QuerySet): + # Can do a real object_id IN (SELECT ..) query. + return { + "object_id__in": instances, + "content_type": ContentType.objects.get_for_model(instances.model), + } + else: + # TODO: instances[0], can we assume there are instances. + return { + "object_id__in": [instance.pk for instance in instances], + "content_type": ContentType.objects.get_for_model(instances[0]), + } @classmethod def tags_for(cls, model, instance=None): diff --git a/awx/lib/site-packages/taggit/tests/__init__.py b/awx/lib/site-packages/taggit/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 --- a/awx/lib/site-packages/taggit/tests/__init__.py +++ /dev/null diff --git a/awx/lib/site-packages/taggit/tests/forms.py b/awx/lib/site-packages/taggit/tests/forms.py deleted file mode 100644 index 2cdc6a8dca..0000000000 --- a/awx/lib/site-packages/taggit/tests/forms.py +++ /dev/null @@ -1,20 +0,0 @@ -from django import forms - -from taggit.tests.models import Food, DirectFood, CustomPKFood, OfficialFood - - -class FoodForm(forms.ModelForm): - class Meta: - model = Food - -class DirectFoodForm(forms.ModelForm): - class Meta: - model = DirectFood - -class CustomPKFoodForm(forms.ModelForm): - class Meta: - model = CustomPKFood - -class OfficialFoodForm(forms.ModelForm): - class Meta: - model = OfficialFood diff --git a/awx/lib/site-packages/taggit/tests/models.py b/awx/lib/site-packages/taggit/tests/models.py deleted file mode 100644 index a0e21e046f..0000000000 --- a/awx/lib/site-packages/taggit/tests/models.py +++ /dev/null @@ -1,143 +0,0 @@ -from django.db import models - -from taggit.managers import TaggableManager -from taggit.models import (TaggedItemBase, GenericTaggedItemBase, TaggedItem, - TagBase, Tag) - - -class Food(models.Model): - name = models.CharField(max_length=50) - - tags = TaggableManager() - - def __unicode__(self): - return self.name - -class Pet(models.Model): - name = models.CharField(max_length=50) - - tags = TaggableManager() - - def __unicode__(self): - return self.name - -class HousePet(Pet): - trained = models.BooleanField() - - -# Test direct-tagging with custom through model - -class TaggedFood(TaggedItemBase): - content_object = models.ForeignKey('DirectFood') - -class TaggedPet(TaggedItemBase): - content_object = models.ForeignKey('DirectPet') - -class DirectFood(models.Model): - name = models.CharField(max_length=50) - - tags = TaggableManager(through="TaggedFood") - -class DirectPet(models.Model): - name = models.CharField(max_length=50) - - tags = TaggableManager(through=TaggedPet) - - def __unicode__(self): - return self.name - -class DirectHousePet(DirectPet): - trained = models.BooleanField() - - -# Test custom through model to model with custom PK - -class TaggedCustomPKFood(TaggedItemBase): - content_object = models.ForeignKey('CustomPKFood') - -class TaggedCustomPKPet(TaggedItemBase): - content_object = models.ForeignKey('CustomPKPet') - -class CustomPKFood(models.Model): - name = models.CharField(max_length=50, primary_key=True) - - tags = TaggableManager(through=TaggedCustomPKFood) - - def __unicode__(self): - return self.name - -class CustomPKPet(models.Model): - name = models.CharField(max_length=50, primary_key=True) - - tags = TaggableManager(through=TaggedCustomPKPet) - - def __unicode__(self): - return self.name - -class CustomPKHousePet(CustomPKPet): - trained = models.BooleanField() - -# Test custom through model to a custom tag model - -class OfficialTag(TagBase): - official = models.BooleanField() - -class OfficialThroughModel(GenericTaggedItemBase): - tag = models.ForeignKey(OfficialTag, related_name="tagged_items") - -class OfficialFood(models.Model): - name = models.CharField(max_length=50) - - tags = TaggableManager(through=OfficialThroughModel) - - def __unicode__(self): - return self.name - -class OfficialPet(models.Model): - name = models.CharField(max_length=50) - - tags = TaggableManager(through=OfficialThroughModel) - - def __unicode__(self): - return self.name - -class OfficialHousePet(OfficialPet): - trained = models.BooleanField() - - -class Media(models.Model): - tags = TaggableManager() - - class Meta: - abstract = True - -class Photo(Media): - pass - -class Movie(Media): - pass - - -class ArticleTag(Tag): - class Meta: - proxy = True - - def slugify(self, tag, i=None): - slug = "category-%s" % tag.lower() - - if i is not None: - slug += "-%d" % i - return slug - -class ArticleTaggedItem(TaggedItem): - class Meta: - proxy = True - - @classmethod - def tag_model(self): - return ArticleTag - -class Article(models.Model): - title = models.CharField(max_length=100) - - tags = TaggableManager(through=ArticleTaggedItem) diff --git a/awx/lib/site-packages/taggit/tests/runtests.py b/awx/lib/site-packages/taggit/tests/runtests.py deleted file mode 100644 index 3e52cf18fe..0000000000 --- a/awx/lib/site-packages/taggit/tests/runtests.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -import os -import sys - -from django.conf import settings - -if not settings.configured: - settings.configure( - DATABASES={ - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - } - }, - INSTALLED_APPS=[ - 'django.contrib.contenttypes', - 'taggit', - 'taggit.tests', - ] - ) - - -from django.test.simple import DjangoTestSuiteRunner - -def runtests(): - runner = DjangoTestSuiteRunner() - failures = runner.run_tests(['tests'], verbosity=1, interactive=True) - sys.exit(failures) - -if __name__ == '__main__': - runtests(*sys.argv[1:]) - diff --git a/awx/lib/site-packages/taggit/tests/tests.py b/awx/lib/site-packages/taggit/tests/tests.py deleted file mode 100644 index 6282db84ab..0000000000 --- a/awx/lib/site-packages/taggit/tests/tests.py +++ /dev/null @@ -1,486 +0,0 @@ -from unittest import TestCase as UnitTestCase - -import django -from django.conf import settings -from django.core.exceptions import ValidationError -from django.db import connection -from django.test import TestCase, TransactionTestCase - -from taggit.managers import TaggableManager -from taggit.models import Tag, TaggedItem -from taggit.tests.forms import (FoodForm, DirectFoodForm, CustomPKFoodForm, - OfficialFoodForm) -from taggit.tests.models import (Food, Pet, HousePet, DirectFood, DirectPet, - DirectHousePet, TaggedPet, CustomPKFood, CustomPKPet, CustomPKHousePet, - TaggedCustomPKPet, OfficialFood, OfficialPet, OfficialHousePet, - OfficialThroughModel, OfficialTag, Photo, Movie, Article) -from taggit.utils import parse_tags, edit_string_for_tags - - -class BaseTaggingTest(object): - def assert_tags_equal(self, qs, tags, sort=True, attr="name"): - got = map(lambda tag: getattr(tag, attr), qs) - if sort: - got.sort() - tags.sort() - self.assertEqual(got, tags) - - def assert_num_queries(self, n, f, *args, **kwargs): - original_DEBUG = settings.DEBUG - settings.DEBUG = True - current = len(connection.queries) - try: - f(*args, **kwargs) - self.assertEqual( - len(connection.queries) - current, - n, - ) - finally: - settings.DEBUG = original_DEBUG - - def _get_form_str(self, form_str): - if django.VERSION >= (1, 3): - form_str %= { - "help_start": '<span class="helptext">', - "help_stop": "</span>" - } - else: - form_str %= { - "help_start": "", - "help_stop": "" - } - return form_str - - def assert_form_renders(self, form, html): - try: - self.assertHTMLEqual(str(form), self._get_form_str(html)) - except AttributeError: - self.assertEqual(str(form), self._get_form_str(html)) - - -class BaseTaggingTestCase(TestCase, BaseTaggingTest): - pass - -class BaseTaggingTransactionTestCase(TransactionTestCase, BaseTaggingTest): - pass - - -class TagModelTestCase(BaseTaggingTransactionTestCase): - food_model = Food - tag_model = Tag - - def test_unique_slug(self): - apple = self.food_model.objects.create(name="apple") - apple.tags.add("Red", "red") - - def test_update(self): - special = self.tag_model.objects.create(name="special") - special.save() - - def test_add(self): - apple = self.food_model.objects.create(name="apple") - yummy = self.tag_model.objects.create(name="yummy") - apple.tags.add(yummy) - - def test_slugify(self): - a = Article.objects.create(title="django-taggit 1.0 Released") - a.tags.add("awesome", "release", "AWESOME") - self.assert_tags_equal(a.tags.all(), [ - "category-awesome", - "category-release", - "category-awesome-1" - ], attr="slug") - -class TagModelDirectTestCase(TagModelTestCase): - food_model = DirectFood - tag_model = Tag - -class TagModelCustomPKTestCase(TagModelTestCase): - food_model = CustomPKFood - tag_model = Tag - -class TagModelOfficialTestCase(TagModelTestCase): - food_model = OfficialFood - tag_model = OfficialTag - -class TaggableManagerTestCase(BaseTaggingTestCase): - food_model = Food - pet_model = Pet - housepet_model = HousePet - taggeditem_model = TaggedItem - tag_model = Tag - - def test_add_tag(self): - apple = self.food_model.objects.create(name="apple") - self.assertEqual(list(apple.tags.all()), []) - self.assertEqual(list(self.food_model.tags.all()), []) - - apple.tags.add('green') - self.assert_tags_equal(apple.tags.all(), ['green']) - self.assert_tags_equal(self.food_model.tags.all(), ['green']) - - pear = self.food_model.objects.create(name="pear") - pear.tags.add('green') - self.assert_tags_equal(pear.tags.all(), ['green']) - self.assert_tags_equal(self.food_model.tags.all(), ['green']) - - apple.tags.add('red') - self.assert_tags_equal(apple.tags.all(), ['green', 'red']) - self.assert_tags_equal(self.food_model.tags.all(), ['green', 'red']) - - self.assert_tags_equal( - self.food_model.tags.most_common(), - ['green', 'red'], - sort=False - ) - - apple.tags.remove('green') - self.assert_tags_equal(apple.tags.all(), ['red']) - self.assert_tags_equal(self.food_model.tags.all(), ['green', 'red']) - tag = self.tag_model.objects.create(name="delicious") - apple.tags.add(tag) - self.assert_tags_equal(apple.tags.all(), ["red", "delicious"]) - - apple.delete() - self.assert_tags_equal(self.food_model.tags.all(), ["green"]) - - def test_add_queries(self): - apple = self.food_model.objects.create(name="apple") - # 1 query to see which tags exist - # + 3 queries to create the tags. - # + 6 queries to create the intermediary things (including SELECTs, to - # make sure we don't double create. - self.assert_num_queries(10, apple.tags.add, "red", "delicious", "green") - - pear = self.food_model.objects.create(name="pear") - # 1 query to see which tags exist - # + 4 queries to create the intermeidary things (including SELECTs, to - # make sure we dont't double create. - self.assert_num_queries(5, pear.tags.add, "green", "delicious") - - self.assert_num_queries(0, pear.tags.add) - - def test_require_pk(self): - food_instance = self.food_model() - self.assertRaises(ValueError, lambda: food_instance.tags.all()) - - def test_delete_obj(self): - apple = self.food_model.objects.create(name="apple") - apple.tags.add("red") - self.assert_tags_equal(apple.tags.all(), ["red"]) - strawberry = self.food_model.objects.create(name="strawberry") - strawberry.tags.add("red") - apple.delete() - self.assert_tags_equal(strawberry.tags.all(), ["red"]) - - def test_delete_bulk(self): - apple = self.food_model.objects.create(name="apple") - kitty = self.pet_model.objects.create(pk=apple.pk, name="kitty") - - apple.tags.add("red", "delicious", "fruit") - kitty.tags.add("feline") - - self.food_model.objects.all().delete() - - self.assert_tags_equal(kitty.tags.all(), ["feline"]) - - def test_lookup_by_tag(self): - apple = self.food_model.objects.create(name="apple") - apple.tags.add("red", "green") - pear = self.food_model.objects.create(name="pear") - pear.tags.add("green") - - self.assertEqual( - list(self.food_model.objects.filter(tags__name__in=["red"])), - [apple] - ) - self.assertEqual( - list(self.food_model.objects.filter(tags__name__in=["green"])), - [apple, pear] - ) - - kitty = self.pet_model.objects.create(name="kitty") - kitty.tags.add("fuzzy", "red") - dog = self.pet_model.objects.create(name="dog") - dog.tags.add("woof", "red") - self.assertEqual( - list(self.food_model.objects.filter(tags__name__in=["red"]).distinct()), - [apple] - ) - - tag = self.tag_model.objects.get(name="woof") - self.assertEqual(list(self.pet_model.objects.filter(tags__in=[tag])), [dog]) - - cat = self.housepet_model.objects.create(name="cat", trained=True) - cat.tags.add("fuzzy") - - self.assertEqual( - map(lambda o: o.pk, self.pet_model.objects.filter(tags__name__in=["fuzzy"])), - [kitty.pk, cat.pk] - ) - - def test_exclude(self): - apple = self.food_model.objects.create(name="apple") - apple.tags.add("red", "green", "delicious") - - pear = self.food_model.objects.create(name="pear") - pear.tags.add("green", "delicious") - - guava = self.food_model.objects.create(name="guava") - - self.assertEqual( - sorted(map(lambda o: o.pk, self.food_model.objects.exclude(tags__name__in=["red"]))), - sorted([pear.pk, guava.pk]), - ) - - def test_similarity_by_tag(self): - """Test that pears are more similar to apples than watermelons""" - apple = self.food_model.objects.create(name="apple") - apple.tags.add("green", "juicy", "small", "sour") - - pear = self.food_model.objects.create(name="pear") - pear.tags.add("green", "juicy", "small", "sweet") - - watermelon = self.food_model.objects.create(name="watermelon") - watermelon.tags.add("green", "juicy", "large", "sweet") - - similar_objs = apple.tags.similar_objects() - self.assertEqual(similar_objs, [pear, watermelon]) - self.assertEqual(map(lambda x: x.similar_tags, similar_objs), [3, 2]) - - def test_tag_reuse(self): - apple = self.food_model.objects.create(name="apple") - apple.tags.add("juicy", "juicy") - self.assert_tags_equal(apple.tags.all(), ['juicy']) - - def test_query_traverse(self): - spot = self.pet_model.objects.create(name='Spot') - spike = self.pet_model.objects.create(name='Spike') - spot.tags.add('scary') - spike.tags.add('fluffy') - lookup_kwargs = { - '%s__name' % self.pet_model._meta.module_name: 'Spot' - } - self.assert_tags_equal( - self.tag_model.objects.filter(**lookup_kwargs), - ['scary'] - ) - - def test_taggeditem_unicode(self): - ross = self.pet_model.objects.create(name="ross") - # I keep Ross Perot for a pet, what's it to you? - ross.tags.add("president") - - self.assertEqual( - unicode(self.taggeditem_model.objects.all()[0]), - "ross tagged with president" - ) - - def test_abstract_subclasses(self): - p = Photo.objects.create() - p.tags.add("outdoors", "pretty") - self.assert_tags_equal( - p.tags.all(), - ["outdoors", "pretty"] - ) - - m = Movie.objects.create() - m.tags.add("hd") - self.assert_tags_equal( - m.tags.all(), - ["hd"], - ) - - def test_field_api(self): - # Check if tag field, which simulates m2m, has django-like api. - field = self.food_model._meta.get_field('tags') - self.assertTrue(hasattr(field, 'rel')) - self.assertTrue(hasattr(field, 'related')) - self.assertEqual(self.food_model, field.related.model) - - -class TaggableManagerDirectTestCase(TaggableManagerTestCase): - food_model = DirectFood - pet_model = DirectPet - housepet_model = DirectHousePet - taggeditem_model = TaggedPet - -class TaggableManagerCustomPKTestCase(TaggableManagerTestCase): - food_model = CustomPKFood - pet_model = CustomPKPet - housepet_model = CustomPKHousePet - taggeditem_model = TaggedCustomPKPet - - def test_require_pk(self): - # TODO with a charfield pk, pk is never None, so taggit has no way to - # tell if the instance is saved or not - pass - -class TaggableManagerOfficialTestCase(TaggableManagerTestCase): - food_model = OfficialFood - pet_model = OfficialPet - housepet_model = OfficialHousePet - taggeditem_model = OfficialThroughModel - tag_model = OfficialTag - - def test_extra_fields(self): - self.tag_model.objects.create(name="red") - self.tag_model.objects.create(name="delicious", official=True) - apple = self.food_model.objects.create(name="apple") - apple.tags.add("delicious", "red") - - pear = self.food_model.objects.create(name="Pear") - pear.tags.add("delicious") - - self.assertEqual( - map(lambda o: o.pk, self.food_model.objects.filter(tags__official=False)), - [apple.pk], - ) - - -class TaggableFormTestCase(BaseTaggingTestCase): - form_class = FoodForm - food_model = Food - - def test_form(self): - self.assertEqual(self.form_class.base_fields.keys(), ['name', 'tags']) - - f = self.form_class({'name': 'apple', 'tags': 'green, red, yummy'}) - self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr> -<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value="green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""") - f.save() - apple = self.food_model.objects.get(name='apple') - self.assert_tags_equal(apple.tags.all(), ['green', 'red', 'yummy']) - - f = self.form_class({'name': 'apple', 'tags': 'green, red, yummy, delicious'}, instance=apple) - f.save() - apple = self.food_model.objects.get(name='apple') - self.assert_tags_equal(apple.tags.all(), ['green', 'red', 'yummy', 'delicious']) - self.assertEqual(self.food_model.objects.count(), 1) - - f = self.form_class({"name": "raspberry"}) - self.assertFalse(f.is_valid()) - - f = self.form_class(instance=apple) - self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr> -<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value="delicious, green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""") - - apple.tags.add('has,comma') - f = self.form_class(instance=apple) - self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr> -<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value=""has,comma", delicious, green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""") - - apple.tags.add('has space') - f = self.form_class(instance=apple) - self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr> -<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value=""has space", "has,comma", delicious, green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""") - - def test_formfield(self): - tm = TaggableManager(verbose_name='categories', help_text='Add some categories', blank=True) - ff = tm.formfield() - self.assertEqual(ff.label, 'Categories') - self.assertEqual(ff.help_text, u'Add some categories') - self.assertEqual(ff.required, False) - - self.assertEqual(ff.clean(""), []) - - tm = TaggableManager() - ff = tm.formfield() - self.assertRaises(ValidationError, ff.clean, "") - -class TaggableFormDirectTestCase(TaggableFormTestCase): - form_class = DirectFoodForm - food_model = DirectFood - -class TaggableFormCustomPKTestCase(TaggableFormTestCase): - form_class = CustomPKFoodForm - food_model = CustomPKFood - -class TaggableFormOfficialTestCase(TaggableFormTestCase): - form_class = OfficialFoodForm - food_model = OfficialFood - - -class TagStringParseTestCase(UnitTestCase): - """ - Ported from Jonathan Buchanan's `django-tagging - <http://django-tagging.googlecode.com/>`_ - """ - - def test_with_simple_space_delimited_tags(self): - """ - Test with simple space-delimited tags. - """ - self.assertEqual(parse_tags('one'), [u'one']) - self.assertEqual(parse_tags('one two'), [u'one', u'two']) - self.assertEqual(parse_tags('one two three'), [u'one', u'three', u'two']) - self.assertEqual(parse_tags('one one two two'), [u'one', u'two']) - - def test_with_comma_delimited_multiple_words(self): - """ - Test with comma-delimited multiple words. - An unquoted comma in the input will trigger this. - """ - self.assertEqual(parse_tags(',one'), [u'one']) - self.assertEqual(parse_tags(',one two'), [u'one two']) - self.assertEqual(parse_tags(',one two three'), [u'one two three']) - self.assertEqual(parse_tags('a-one, a-two and a-three'), - [u'a-one', u'a-two and a-three']) - - def test_with_double_quoted_multiple_words(self): - """ - Test with double-quoted multiple words. - A completed quote will trigger this. Unclosed quotes are ignored. - """ - self.assertEqual(parse_tags('"one'), [u'one']) - self.assertEqual(parse_tags('"one two'), [u'one', u'two']) - self.assertEqual(parse_tags('"one two three'), [u'one', u'three', u'two']) - self.assertEqual(parse_tags('"one two"'), [u'one two']) - self.assertEqual(parse_tags('a-one "a-two and a-three"'), - [u'a-one', u'a-two and a-three']) - - def test_with_no_loose_commas(self): - """ - Test with no loose commas -- split on spaces. - """ - self.assertEqual(parse_tags('one two "thr,ee"'), [u'one', u'thr,ee', u'two']) - - def test_with_loose_commas(self): - """ - Loose commas - split on commas - """ - self.assertEqual(parse_tags('"one", two three'), [u'one', u'two three']) - - def test_tags_with_double_quotes_can_contain_commas(self): - """ - Double quotes can contain commas - """ - self.assertEqual(parse_tags('a-one "a-two, and a-three"'), - [u'a-one', u'a-two, and a-three']) - self.assertEqual(parse_tags('"two", one, one, two, "one"'), - [u'one', u'two']) - - def test_with_naughty_input(self): - """ - Test with naughty input. - """ - # Bad users! Naughty users! - self.assertEqual(parse_tags(None), []) - self.assertEqual(parse_tags(''), []) - self.assertEqual(parse_tags('"'), []) - self.assertEqual(parse_tags('""'), []) - self.assertEqual(parse_tags('"' * 7), []) - self.assertEqual(parse_tags(',,,,,,'), []) - self.assertEqual(parse_tags('",",",",",",","'), [u',']) - self.assertEqual(parse_tags('a-one "a-two" and "a-three'), - [u'a-one', u'a-three', u'a-two', u'and']) - - def test_recreation_of_tag_list_string_representations(self): - plain = Tag.objects.create(name='plain') - spaces = Tag.objects.create(name='spa ces') - comma = Tag.objects.create(name='com,ma') - self.assertEqual(edit_string_for_tags([plain]), u'plain') - self.assertEqual(edit_string_for_tags([plain, spaces]), u'"spa ces", plain') - self.assertEqual(edit_string_for_tags([plain, spaces, comma]), u'"com,ma", "spa ces", plain') - self.assertEqual(edit_string_for_tags([plain, comma]), u'"com,ma", plain') - self.assertEqual(edit_string_for_tags([comma, spaces]), u'"com,ma", "spa ces"') diff --git a/awx/lib/site-packages/taggit/utils.py b/awx/lib/site-packages/taggit/utils.py index 1b5e5a7f11..997c4f0a1b 100644 --- a/awx/lib/site-packages/taggit/utils.py +++ b/awx/lib/site-packages/taggit/utils.py @@ -1,5 +1,8 @@ -from django.utils.encoding import force_unicode +from __future__ import unicode_literals + +from django.utils.encoding import force_text from django.utils.functional import wraps +from django.utils import six def parse_tags(tagstring): @@ -16,13 +19,13 @@ def parse_tags(tagstring): if not tagstring: return [] - tagstring = force_unicode(tagstring) + tagstring = force_text(tagstring) # Special case - if there are no commas or double quotes in the # input, we don't *do* a recall... I mean, we know we only need to # split on spaces. - if u',' not in tagstring and u'"' not in tagstring: - words = list(set(split_strip(tagstring, u' '))) + if ',' not in tagstring and '"' not in tagstring: + words = list(set(split_strip(tagstring, ' '))) words.sort() return words @@ -36,39 +39,39 @@ def parse_tags(tagstring): i = iter(tagstring) try: while True: - c = i.next() - if c == u'"': + c = six.next(i) + if c == '"': if buffer: - to_be_split.append(u''.join(buffer)) + to_be_split.append(''.join(buffer)) buffer = [] # Find the matching quote open_quote = True - c = i.next() - while c != u'"': + c = six.next(i) + while c != '"': buffer.append(c) - c = i.next() + c = six.next(i) if buffer: - word = u''.join(buffer).strip() + word = ''.join(buffer).strip() if word: words.append(word) buffer = [] open_quote = False else: - if not saw_loose_comma and c == u',': + if not saw_loose_comma and c == ',': saw_loose_comma = True buffer.append(c) except StopIteration: # If we were parsing an open quote which was never closed treat # the buffer as unquoted. if buffer: - if open_quote and u',' in buffer: + if open_quote and ',' in buffer: saw_loose_comma = True - to_be_split.append(u''.join(buffer)) + to_be_split.append(''.join(buffer)) if to_be_split: if saw_loose_comma: - delimiter = u',' + delimiter = ',' else: - delimiter = u' ' + delimiter = ' ' for chunk in to_be_split: words.extend(split_strip(chunk, delimiter)) words = list(set(words)) @@ -76,7 +79,7 @@ def parse_tags(tagstring): return words -def split_strip(string, delimiter=u','): +def split_strip(string, delimiter=','): """ Splits ``string`` on ``delimiter``, stripping each resulting string and returning a list of non-empty strings. @@ -110,11 +113,11 @@ def edit_string_for_tags(tags): names = [] for tag in tags: name = tag.name - if u',' in name or u' ' in name: + if ',' in name or ' ' in name: names.append('"%s"' % name) else: names.append(name) - return u', '.join(sorted(names)) + return ', '.join(sorted(names)) def require_instance_manager(func): diff --git a/awx/lib/site-packages/taggit/views.py b/awx/lib/site-packages/taggit/views.py index 1e407f41c9..68c06dbf9b 100644 --- a/awx/lib/site-packages/taggit/views.py +++ b/awx/lib/site-packages/taggit/views.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + from django.contrib.contenttypes.models import ContentType from django.shortcuts import get_object_or_404 from django.views.generic.list import ListView |