summaryrefslogtreecommitdiffstats
path: root/awx/lib/site-packages
diff options
context:
space:
mode:
Diffstat (limited to 'awx/lib/site-packages')
-rw-r--r--awx/lib/site-packages/README18
-rw-r--r--awx/lib/site-packages/amqp/__init__.py2
-rw-r--r--awx/lib/site-packages/amqp/basic_message.py3
-rw-r--r--awx/lib/site-packages/amqp/method_framing.py43
-rw-r--r--awx/lib/site-packages/amqp/transport.py82
-rw-r--r--awx/lib/site-packages/billiard/__init__.py8
-rw-r--r--awx/lib/site-packages/billiard/_ext.py8
-rw-r--r--awx/lib/site-packages/billiard/common.py34
-rw-r--r--awx/lib/site-packages/billiard/pool.py57
-rw-r--r--awx/lib/site-packages/billiard/process.py38
-rw-r--r--awx/lib/site-packages/billiard/queues.py10
-rw-r--r--awx/lib/site-packages/billiard/util.py8
-rw-r--r--awx/lib/site-packages/celery/__compat__.py7
-rw-r--r--awx/lib/site-packages/celery/__init__.py2
-rw-r--r--awx/lib/site-packages/celery/app/builtins.py5
-rw-r--r--awx/lib/site-packages/celery/app/task.py12
-rw-r--r--awx/lib/site-packages/celery/apps/worker.py2
-rw-r--r--awx/lib/site-packages/celery/backends/mongodb.py8
-rw-r--r--awx/lib/site-packages/celery/bin/celery.py4
-rw-r--r--awx/lib/site-packages/celery/canvas.py7
-rw-r--r--awx/lib/site-packages/celery/concurrency/eventlet.py3
-rw-r--r--awx/lib/site-packages/celery/contrib/migrate.py2
-rw-r--r--awx/lib/site-packages/celery/datastructures.py5
-rw-r--r--awx/lib/site-packages/celery/events/state.py34
-rw-r--r--awx/lib/site-packages/celery/loaders/__init__.py4
-rw-r--r--awx/lib/site-packages/celery/platforms.py22
-rw-r--r--awx/lib/site-packages/celery/result.py10
-rw-r--r--awx/lib/site-packages/celery/schedules.py19
-rw-r--r--awx/lib/site-packages/celery/task/trace.py9
-rw-r--r--awx/lib/site-packages/celery/tests/backends/test_base.py3
-rw-r--r--awx/lib/site-packages/celery/tests/backends/test_mongodb.py2
-rw-r--r--awx/lib/site-packages/celery/tests/tasks/test_canvas.py5
-rw-r--r--awx/lib/site-packages/celery/tests/tasks/test_tasks.py44
-rw-r--r--awx/lib/site-packages/celery/tests/utilities/test_encoding.py15
-rw-r--r--awx/lib/site-packages/celery/tests/utilities/test_term.py4
-rw-r--r--awx/lib/site-packages/celery/tests/worker/test_worker.py22
-rw-r--r--awx/lib/site-packages/celery/utils/imports.py11
-rw-r--r--awx/lib/site-packages/celery/utils/serialization.py37
-rw-r--r--awx/lib/site-packages/celery/utils/threads.py19
-rw-r--r--awx/lib/site-packages/celery/worker/__init__.py83
-rw-r--r--awx/lib/site-packages/celery/worker/buckets.py10
-rw-r--r--awx/lib/site-packages/celery/worker/consumer.py18
-rw-r--r--awx/lib/site-packages/celery/worker/mediator.py2
-rw-r--r--awx/lib/site-packages/django_extensions/__init__.py2
-rw-r--r--awx/lib/site-packages/django_extensions/admin/__init__.py6
-rw-r--r--awx/lib/site-packages/django_extensions/admin/widgets.py5
-rw-r--r--awx/lib/site-packages/django_extensions/db/fields/__init__.py29
-rw-r--r--awx/lib/site-packages/django_extensions/db/fields/encrypted.py44
-rw-r--r--awx/lib/site-packages/django_extensions/db/fields/json.py9
-rw-r--r--awx/lib/site-packages/django_extensions/future_1_5.py16
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/dumpscript.py107
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/export_emails.py18
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/graph_models.py2
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/passwd.py6
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/pipchecker.py39
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py6
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/reset_db.py4
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py12
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py75
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py7
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py7
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py4
-rw-r--r--awx/lib/site-packages/django_extensions/management/commands/sqldiff.py8
-rw-r--r--awx/lib/site-packages/django_extensions/management/modelviz.py6
-rw-r--r--awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py2
-rw-r--r--awx/lib/site-packages/django_extensions/templatetags/widont.py5
-rw-r--r--awx/lib/site-packages/django_extensions/tests/json_field.py8
-rw-r--r--awx/lib/site-packages/django_extensions/tests/test_dumpscript.py4
-rw-r--r--awx/lib/site-packages/django_extensions/tests/utils.py21
-rw-r--r--awx/lib/site-packages/django_extensions/tests/uuid_field.py23
-rw-r--r--awx/lib/site-packages/django_extensions/utils/dia2django.py5
-rw-r--r--awx/lib/site-packages/django_extensions/utils/uuid.py566
-rw-r--r--awx/lib/site-packages/djcelery/__init__.py2
-rw-r--r--awx/lib/site-packages/djcelery/admin.py8
-rw-r--r--awx/lib/site-packages/djcelery/picklefield.py8
-rw-r--r--awx/lib/site-packages/djcelery/schedulers.py4
-rw-r--r--awx/lib/site-packages/djcelery/snapshot.py4
-rw-r--r--awx/lib/site-packages/djcelery/utils.py3
-rw-r--r--awx/lib/site-packages/kombu/__init__.py2
-rw-r--r--awx/lib/site-packages/kombu/abstract.py4
-rw-r--r--awx/lib/site-packages/kombu/connection.py80
-rw-r--r--awx/lib/site-packages/kombu/entity.py15
-rw-r--r--awx/lib/site-packages/kombu/messaging.py35
-rw-r--r--awx/lib/site-packages/kombu/mixins.py12
-rw-r--r--awx/lib/site-packages/kombu/pidbox.py13
-rw-r--r--awx/lib/site-packages/kombu/tests/__init__.py19
-rw-r--r--awx/lib/site-packages/kombu/tests/test_entities.py4
-rw-r--r--awx/lib/site-packages/kombu/tests/test_messaging.py6
-rw-r--r--awx/lib/site-packages/kombu/tests/test_serialization.py15
-rw-r--r--awx/lib/site-packages/kombu/tests/test_utils.py4
-rw-r--r--awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py3
-rw-r--r--awx/lib/site-packages/kombu/tests/utilities/test_encoding.py40
-rw-r--r--awx/lib/site-packages/kombu/transport/librabbitmq.py40
-rw-r--r--awx/lib/site-packages/kombu/transport/mongodb.py49
-rw-r--r--awx/lib/site-packages/kombu/transport/pyamqp.py30
-rw-r--r--awx/lib/site-packages/kombu/transport/redis.py13
-rw-r--r--awx/lib/site-packages/kombu/utils/__init__.py2
-rw-r--r--awx/lib/site-packages/kombu/utils/encoding.py2
-rw-r--r--awx/lib/site-packages/kombu/utils/eventio.py33
-rw-r--r--awx/lib/site-packages/rest_framework/__init__.py2
-rw-r--r--awx/lib/site-packages/rest_framework/authentication.py42
-rw-r--r--awx/lib/site-packages/rest_framework/authtoken/admin.py11
-rw-r--r--awx/lib/site-packages/rest_framework/authtoken/models.py4
-rw-r--r--awx/lib/site-packages/rest_framework/compat.py65
-rw-r--r--awx/lib/site-packages/rest_framework/exceptions.py7
-rw-r--r--awx/lib/site-packages/rest_framework/fields.py70
-rw-r--r--awx/lib/site-packages/rest_framework/filters.py5
-rw-r--r--awx/lib/site-packages/rest_framework/generics.py4
-rw-r--r--awx/lib/site-packages/rest_framework/parsers.py25
-rw-r--r--awx/lib/site-packages/rest_framework/permissions.py2
-rw-r--r--awx/lib/site-packages/rest_framework/relations.py16
-rw-r--r--awx/lib/site-packages/rest_framework/renderers.py15
-rw-r--r--awx/lib/site-packages/rest_framework/request.py20
-rw-r--r--awx/lib/site-packages/rest_framework/response.py2
-rw-r--r--awx/lib/site-packages/rest_framework/routers.py37
-rw-r--r--awx/lib/site-packages/rest_framework/runtests/settings.py2
-rw-r--r--awx/lib/site-packages/rest_framework/serializers.py54
-rw-r--r--awx/lib/site-packages/rest_framework/settings.py8
-rw-r--r--awx/lib/site-packages/rest_framework/templates/rest_framework/base.html2
-rw-r--r--awx/lib/site-packages/rest_framework/test.py157
-rw-r--r--awx/lib/site-packages/rest_framework/tests/description.py26
-rw-r--r--awx/lib/site-packages/rest_framework/tests/models.py2
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_authentication.py87
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_decorators.py11
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_description.py13
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_fields.py32
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_filters.py4
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_generics.py60
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py38
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_negotiation.py4
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_pagination.py6
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_permissions.py35
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py4
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_renderers.py8
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_request.py19
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_reverse.py4
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_routers.py109
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_serializer.py89
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_testing.py115
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_throttling.py144
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py4
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_validation.py8
-rw-r--r--awx/lib/site-packages/rest_framework/tests/test_views.py6
-rw-r--r--awx/lib/site-packages/rest_framework/tests/utils.py40
-rw-r--r--awx/lib/site-packages/rest_framework/throttling.py42
-rw-r--r--awx/lib/site-packages/rest_framework/utils/formatting.py4
-rw-r--r--awx/lib/site-packages/rest_framework/views.py26
-rw-r--r--awx/lib/site-packages/south/__init__.py2
-rw-r--r--awx/lib/site-packages/south/creator/actions.py25
-rw-r--r--awx/lib/site-packages/south/creator/changes.py28
-rw-r--r--awx/lib/site-packages/south/db/__init__.py3
-rw-r--r--awx/lib/site-packages/south/db/firebird.py10
-rw-r--r--awx/lib/site-packages/south/db/generic.py18
-rw-r--r--awx/lib/site-packages/south/db/oracle.py22
-rw-r--r--awx/lib/site-packages/south/db/sql_server/pyodbc.py13
-rw-r--r--awx/lib/site-packages/south/db/sqlite3.py11
-rw-r--r--awx/lib/site-packages/south/hacks/django_1_0.py3
-rw-r--r--awx/lib/site-packages/south/management/commands/datamigration.py19
-rw-r--r--awx/lib/site-packages/south/management/commands/schemamigration.py5
-rw-r--r--awx/lib/site-packages/south/management/commands/syncdb.py2
-rw-r--r--awx/lib/site-packages/south/migration/migrators.py5
-rw-r--r--awx/lib/site-packages/south/orm.py16
-rw-r--r--awx/lib/site-packages/south/tests/db.py91
-rw-r--r--awx/lib/site-packages/south/utils/__init__.py6
-rw-r--r--awx/lib/site-packages/south/utils/py3.py7
-rw-r--r--awx/lib/site-packages/taggit/__init__.py2
-rw-r--r--awx/lib/site-packages/taggit/admin.py2
-rw-r--r--awx/lib/site-packages/taggit/forms.py5
-rw-r--r--awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mobin0 -> 1061 bytes
-rw-r--r--awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po64
-rw-r--r--awx/lib/site-packages/taggit/managers.py258
-rw-r--r--awx/lib/site-packages/taggit/migrations/0001_initial.py38
-rw-r--r--awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py22
-rw-r--r--awx/lib/site-packages/taggit/models.py69
-rw-r--r--awx/lib/site-packages/taggit/tests/__init__.py0
-rw-r--r--awx/lib/site-packages/taggit/tests/forms.py20
-rw-r--r--awx/lib/site-packages/taggit/tests/models.py143
-rw-r--r--awx/lib/site-packages/taggit/tests/runtests.py31
-rw-r--r--awx/lib/site-packages/taggit/tests/tests.py486
-rw-r--r--awx/lib/site-packages/taggit/utils.py41
-rw-r--r--awx/lib/site-packages/taggit/views.py2
181 files changed, 2850 insertions, 2346 deletions
diff --git a/awx/lib/site-packages/README b/awx/lib/site-packages/README
index d9d46fd10f..aa3205363d 100644
--- a/awx/lib/site-packages/README
+++ b/awx/lib/site-packages/README
@@ -1,17 +1,17 @@
Local versions of third-party packages required by AWX. Package names and
versions are listed below, along with notes on which files are included.
-amqp-1.0.11 (amqp/*)
+amqp-1.0.13 (amqp/*)
anyjson-0.3.3 (anyjson/*)
-billiard-2.7.3.28 (billiard/*, funtests/*, excluded _billiard.so)
-celery-3.0.19 (celery/*, excluded bin/celery* and bin/camqadm)
-django-celery-3.0.17 (djcelery/*, excluded bin/djcelerymon)
-django-extensions-1.1.1 (django_extensions/*)
+billiard-2.7.3.32 (billiard/*, funtests/*, excluded _billiard.so)
+celery-3.0.22 (celery/*, excluded bin/celery* and bin/camqadm)
+django-celery-3.0.21 (djcelery/*, excluded bin/djcelerymon)
+django-extensions-1.2.0 (django_extensions/*)
django-jsonfield-0.9.10 (jsonfield/*)
-django-taggit-0.10a1 (taggit/*)
-djangorestframework-2.3.5 (rest_framework/*)
+django-taggit-0.10 (taggit/*)
+djangorestframework-2.3.7 (rest_framework/*)
importlib-1.0.2 (importlib/*, needed for Python 2.6 support)
-kombu-2.5.10 (kombu/*)
+kombu-2.5.14 (kombu/*)
Markdown-2.3.1 (markdown/*, excluded bin/markdown_py)
ordereddict-1.1 (ordereddict.py, needed for Python 2.6 support)
pexpect-2.4 (pexpect.py, pxssh.py, fdpexpect.py, FSM.py, screen.py, ANSI.py)
@@ -19,4 +19,4 @@ python-dateutil-2.1 (dateutil/*)
pytz-2013b (pytz/*)
requests-1.2.3 (requests/*)
six-1.3.0 (six.py)
-South-0.8.1 (south/*)
+South-0.8.2 (south/*)
diff --git a/awx/lib/site-packages/amqp/__init__.py b/awx/lib/site-packages/amqp/__init__.py
index 82abe7f721..fca3bb317d 100644
--- a/awx/lib/site-packages/amqp/__init__.py
+++ b/awx/lib/site-packages/amqp/__init__.py
@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import
-VERSION = (1, 0, 11)
+VERSION = (1, 0, 13)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Barry Pederson'
__maintainer__ = 'Ask Solem'
diff --git a/awx/lib/site-packages/amqp/basic_message.py b/awx/lib/site-packages/amqp/basic_message.py
index dc7c7a1237..192ede90b2 100644
--- a/awx/lib/site-packages/amqp/basic_message.py
+++ b/awx/lib/site-packages/amqp/basic_message.py
@@ -44,7 +44,7 @@ class Message(GenericContent):
('cluster_id', 'shortstr')
]
- def __init__(self, body='', children=None, **properties):
+ def __init__(self, body='', children=None, channel=None, **properties):
"""Expected arg types
body: string
@@ -107,6 +107,7 @@ class Message(GenericContent):
"""
super(Message, self).__init__(**properties)
self.body = body
+ self.channel = channel
def __eq__(self, other):
"""Check if the properties and bodies of this Message and another
diff --git a/awx/lib/site-packages/amqp/method_framing.py b/awx/lib/site-packages/amqp/method_framing.py
index 0225ddd35f..a26db6672b 100644
--- a/awx/lib/site-packages/amqp/method_framing.py
+++ b/awx/lib/site-packages/amqp/method_framing.py
@@ -16,9 +16,8 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import
-from collections import defaultdict
+from collections import defaultdict, deque
from struct import pack, unpack
-from Queue import Queue
try:
bytes
@@ -61,12 +60,17 @@ class _PartialMessage(object):
self.complete = (self.body_size == 0)
def add_payload(self, payload):
- self.body_parts.append(payload)
+ parts = self.body_parts
self.body_received += len(payload)
-
if self.body_received == self.body_size:
- self.msg.body = bytes().join(self.body_parts)
+ if parts:
+ parts.append(payload)
+ self.msg.body = bytes().join(parts)
+ else:
+ self.msg.body = payload
self.complete = True
+ else:
+ parts.append(payload)
class MethodReader(object):
@@ -86,7 +90,7 @@ class MethodReader(object):
def __init__(self, source):
self.source = source
- self.queue = Queue()
+ self.queue = deque()
self.running = False
self.partial_messages = {}
self.heartbeats = 0
@@ -94,32 +98,33 @@ class MethodReader(object):
self.expected_types = defaultdict(lambda: 1)
# not an actual byte count, just incremented whenever we receive
self.bytes_recv = 0
+ self._quick_put = self.queue.append
+ self._quick_get = self.queue.popleft
def _next_method(self):
"""Read the next method from the source, once one complete method has
been assembled it is placed in the internal queue."""
- empty = self.queue.empty
+ queue = self.queue
+ put = self._quick_put
read_frame = self.source.read_frame
- while empty():
+ while not queue:
try:
frame_type, channel, payload = read_frame()
except Exception, e:
#
# Connection was closed? Framing Error?
#
- self.queue.put(e)
+ put(e)
break
self.bytes_recv += 1
if frame_type not in (self.expected_types[channel], 8):
- self.queue.put((
+ put((
channel,
AMQPError(
'Received frame type %s while expecting type: %s' % (
- frame_type, self.expected_types[channel])
- ),
- ))
+ frame_type, self.expected_types[channel]))))
elif frame_type == 1:
self._process_method_frame(channel, payload)
elif frame_type == 2:
@@ -144,7 +149,7 @@ class MethodReader(object):
self.partial_messages[channel] = _PartialMessage(method_sig, args)
self.expected_types[channel] = 2
else:
- self.queue.put((channel, method_sig, args, None))
+ self._quick_put((channel, method_sig, args, None))
def _process_content_header(self, channel, payload):
"""Process Content Header frames"""
@@ -155,8 +160,8 @@ class MethodReader(object):
#
# a bodyless message, we're done
#
- self.queue.put((channel, partial.method_sig,
- partial.args, partial.msg))
+ self._quick_put((channel, partial.method_sig,
+ partial.args, partial.msg))
self.partial_messages.pop(channel, None)
self.expected_types[channel] = 1
else:
@@ -174,15 +179,15 @@ class MethodReader(object):
# Stick the message in the queue and go back to
# waiting for method frames
#
- self.queue.put((channel, partial.method_sig,
- partial.args, partial.msg))
+ self._quick_put((channel, partial.method_sig,
+ partial.args, partial.msg))
self.partial_messages.pop(channel, None)
self.expected_types[channel] = 1
def read_method(self):
"""Read a method from the peer."""
self._next_method()
- m = self.queue.get()
+ m = self._quick_get()
if isinstance(m, Exception):
raise m
if isinstance(m, tuple) and isinstance(m[1], AMQPError):
diff --git a/awx/lib/site-packages/amqp/transport.py b/awx/lib/site-packages/amqp/transport.py
index 8092c5fec3..b441a11ced 100644
--- a/awx/lib/site-packages/amqp/transport.py
+++ b/awx/lib/site-packages/amqp/transport.py
@@ -52,6 +52,8 @@ from .exceptions import AMQPError
AMQP_PORT = 5672
+EMPTY_BUFFER = bytes()
+
# Yes, Advanced Message Queuing Protocol Protocol is redundant
AMQP_PROTOCOL_HEADER = 'AMQP\x01\x01\x00\x09'.encode('latin_1')
@@ -139,11 +141,12 @@ class _AbstractTransport(object):
self.sock.close()
self.sock = None
- def read_frame(self):
+ def read_frame(self, unpack=unpack):
"""Read an AMQP frame."""
- frame_type, channel, size = unpack('>BHI', self._read(7, True))
- payload = self._read(size)
- ch = ord(self._read(1))
+ read = self._read
+ frame_type, channel, size = unpack('>BHI', read(7, True))
+ payload = read(size)
+ ch = ord(read(1))
if ch == 206: # '\xce'
return frame_type, channel, payload
else:
@@ -164,7 +167,7 @@ class SSLTransport(_AbstractTransport):
def __init__(self, host, connect_timeout, ssl):
if isinstance(ssl, dict):
self.sslopts = ssl
- self.sslobj = None
+ self._read_buffer = EMPTY_BUFFER
super(SSLTransport, self).__init__(host, connect_timeout)
def _setup_transport(self):
@@ -173,43 +176,51 @@ class SSLTransport(_AbstractTransport):
lower version."""
if HAVE_PY26_SSL:
if hasattr(self, 'sslopts'):
- self.sslobj = ssl.wrap_socket(self.sock, **self.sslopts)
+ self.sock = ssl.wrap_socket(self.sock, **self.sslopts)
else:
- self.sslobj = ssl.wrap_socket(self.sock)
- self.sslobj.do_handshake()
+ self.sock = ssl.wrap_socket(self.sock)
+ self.sock.do_handshake()
else:
- self.sslobj = socket.ssl(self.sock)
+ self.sock = socket.ssl(self.sock)
+ self._quick_recv = self.sock.read
def _shutdown_transport(self):
"""Unwrap a Python 2.6 SSL socket, so we can call shutdown()"""
- if HAVE_PY26_SSL and (self.sslobj is not None):
- self.sock = self.sslobj.unwrap()
- self.sslobj = None
-
- def _read(self, n, initial=False):
- """It seems that SSL Objects read() method may not supply as much
- as you're asking for, at least with extremely large messages.
- somewhere > 16K - found this in the test_channel.py test_large
- unittest."""
- result = ''
-
- while len(result) < n:
+ if HAVE_PY26_SSL and self.sock is not None:
try:
- s = self.sslobj.read(n - len(result))
+ unwrap = self.sock.unwrap
+ except AttributeError:
+ return
+ self.sock = unwrap()
+
+ def _read(self, n, initial=False,
+ _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)):
+ # According to SSL_read(3), it can at most return 16kb of data.
+ # Thus, we use an internal read buffer like TCPTransport._read
+ # to get the exact number of bytes wanted.
+ recv = self._quick_recv
+ rbuf = self._read_buffer
+ while len(rbuf) < n:
+ try:
+ s = recv(131072) # see note above
except socket.error, exc:
- if not initial and exc.errno in (errno.EAGAIN, errno.EINTR):
+ # ssl.sock.read may cause ENOENT if the
+ # operation couldn't be performed (Issue celery#1414).
+ if not initial and exc.errno in _errnos:
continue
- raise
+ raise exc
if not s:
raise IOError('Socket closed')
- result += s
+ rbuf += s
+ result, self._read_buffer = rbuf[:n], rbuf[n:]
return result
def _write(self, s):
"""Write a string out to the SSL socket fully."""
+ write = self.sock.write
while s:
- n = self.sslobj.write(s)
+ n = write(s)
if not n:
raise IOError('Socket closed')
s = s[n:]
@@ -222,24 +233,25 @@ class TCPTransport(_AbstractTransport):
"""Setup to _write() directly to the socket, and
do our own buffered reads."""
self._write = self.sock.sendall
- self._read_buffer = bytes()
+ self._read_buffer = EMPTY_BUFFER
+ self._quick_recv = self.sock.recv
- def _read(self, n, initial=False):
+ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
"""Read exactly n bytes from the socket"""
- while len(self._read_buffer) < n:
+ recv = self._quick_recv
+ rbuf = self._read_buffer
+ while len(rbuf) < n:
try:
- s = self.sock.recv(65536)
+ s = recv(131072)
except socket.error, exc:
- if not initial and exc.errno in (errno.EAGAIN, errno.EINTR):
+ if not initial and exc.errno in _errnos:
continue
raise
if not s:
raise IOError('Socket closed')
- self._read_buffer += s
-
- result = self._read_buffer[:n]
- self._read_buffer = self._read_buffer[n:]
+ rbuf += s
+ result, self._read_buffer = rbuf[:n], rbuf[n:]
return result
diff --git a/awx/lib/site-packages/billiard/__init__.py b/awx/lib/site-packages/billiard/__init__.py
index 5b22a4fa1a..846d01c168 100644
--- a/awx/lib/site-packages/billiard/__init__.py
+++ b/awx/lib/site-packages/billiard/__init__.py
@@ -20,7 +20,7 @@
from __future__ import absolute_import
from __future__ import with_statement
-VERSION = (2, 7, 3, 28)
+VERSION = (2, 7, 3, 32)
__version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:])
__author__ = 'R Oudkerk / Python Software Foundation'
__author_email__ = 'python-dev@python.org'
@@ -232,7 +232,11 @@ def JoinableQueue(maxsize=0):
return JoinableQueue(maxsize)
-def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None):
+def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None,
+ timeout=None, soft_timeout=None, lost_worker_timeout=None,
+ max_restarts=None, max_restart_freq=1, on_process_up=None,
+ on_process_down=None, on_timeout_set=None, on_timeout_cancel=None,
+ threads=True, semaphore=None, putlocks=False, allow_restart=False):
'''
Returns a process pool object
'''
diff --git a/awx/lib/site-packages/billiard/_ext.py b/awx/lib/site-packages/billiard/_ext.py
index a02a156669..7d9caf01ae 100644
--- a/awx/lib/site-packages/billiard/_ext.py
+++ b/awx/lib/site-packages/billiard/_ext.py
@@ -4,6 +4,11 @@ import sys
supports_exec = True
+try:
+ import _winapi as win32
+except ImportError: # pragma: no cover
+ win32 = None
+
if sys.platform.startswith("java"):
_billiard = None
else:
@@ -18,7 +23,8 @@ else:
from multiprocessing.connection import Connection # noqa
PipeConnection = getattr(_billiard, "PipeConnection", None)
- win32 = getattr(_billiard, "win32", None)
+ if win32 is None:
+ win32 = getattr(_billiard, "win32", None) # noqa
def ensure_multiprocessing():
diff --git a/awx/lib/site-packages/billiard/common.py b/awx/lib/site-packages/billiard/common.py
index 04c37ea67d..5c367fd879 100644
--- a/awx/lib/site-packages/billiard/common.py
+++ b/awx/lib/site-packages/billiard/common.py
@@ -1,12 +1,41 @@
+# -*- coding: utf-8 -*-
+"""
+This module contains utilities added by billiard, to keep
+"non-core" functionality out of ``.util``."""
from __future__ import absolute_import
import signal
import sys
from time import time
+import pickle as pypickle
+try:
+ import cPickle as cpickle
+except ImportError: # pragma: no cover
+ cpickle = None # noqa
from .exceptions import RestartFreqExceeded
+if sys.version_info < (2, 6): # pragma: no cover
+ # cPickle does not use absolute_imports
+ pickle = pypickle
+ pickle_load = pypickle.load
+ pickle_loads = pypickle.loads
+else:
+ pickle = cpickle or pypickle
+ pickle_load = pickle.load
+ pickle_loads = pickle.loads
+
+# cPickle.loads does not support buffer() objects,
+# but we can just create a StringIO and use load.
+if sys.version_info[0] == 3:
+ from io import BytesIO
+else:
+ try:
+ from cStringIO import StringIO as BytesIO # noqa
+ except ImportError:
+ from StringIO import StringIO as BytesIO # noqa
+
TERMSIGS = (
'SIGHUP',
'SIGQUIT',
@@ -30,6 +59,11 @@ TERMSIGS = (
)
+def pickle_loads(s, load=pickle_load):
+ # used to support buffer objects
+ return load(BytesIO(s))
+
+
def _shutdown_cleanup(signum, frame):
sys.exit(-(256 - signum))
diff --git a/awx/lib/site-packages/billiard/pool.py b/awx/lib/site-packages/billiard/pool.py
index 71c637656a..33e0fe0d29 100644
--- a/awx/lib/site-packages/billiard/pool.py
+++ b/awx/lib/site-packages/billiard/pool.py
@@ -17,7 +17,6 @@ from __future__ import with_statement
import collections
import errno
import itertools
-import logging
import os
import platform
import signal
@@ -29,7 +28,7 @@ import warnings
from . import Event, Process, cpu_count
from . import util
-from .common import reset_signals, restart_state
+from .common import pickle_loads, reset_signals, restart_state
from .compat import get_errno
from .einfo import ExceptionInfo
from .exceptions import (
@@ -163,15 +162,6 @@ class LaxBoundedSemaphore(_Semaphore):
if self._value < self._initial_value:
self._value += 1
cond.notify_all()
- if __debug__:
- self._note(
- "%s.release: success, value=%s", self, self._value,
- )
- else:
- if __debug__:
- self._note(
- "%s.release: success, value=%s (unchanged)" % (
- self, self._value))
def clear(self):
while self._value < self._initial_value:
@@ -184,14 +174,6 @@ class LaxBoundedSemaphore(_Semaphore):
if self._Semaphore__value < self._initial_value:
self._Semaphore__value += 1
cond.notifyAll()
- if __debug__:
- self._note("%s.release: success, value=%s",
- self, self._Semaphore__value)
- else:
- if __debug__:
- self._note(
- "%s.release: success, value=%s (unchanged)" % (
- self, self._Semaphore__value))
def clear(self): # noqa
while self._Semaphore__value < self._initial_value:
@@ -233,28 +215,26 @@ def soft_timeout_sighandler(signum, frame):
def worker(inqueue, outqueue, initializer=None, initargs=(),
maxtasks=None, sentinel=None):
- # Re-init logging system.
- # Workaround for http://bugs.python.org/issue6721#msg140215
- # Python logging module uses RLock() objects which are broken after
- # fork. This can result in a deadlock (Issue #496).
- logger_names = logging.Logger.manager.loggerDict.keys()
- logger_names.append(None) # for root logger
- for name in logger_names:
- for handler in logging.getLogger(name).handlers:
- handler.createLock()
- logging._lock = threading.RLock()
-
pid = os.getpid()
assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
put = outqueue.put
get = inqueue.get
+ loads = pickle_loads
if hasattr(inqueue, '_reader'):
- def poll(timeout):
- if inqueue._reader.poll(timeout):
- return True, get()
- return False, None
+ if hasattr(inqueue, 'get_payload') and inqueue.get_payload:
+ get_payload = inqueue.get_payload
+
+ def poll(timeout):
+ if inqueue._reader.poll(timeout):
+ return True, loads(get_payload())
+ return False, None
+ else:
+ def poll(timeout):
+ if inqueue._reader.poll(timeout):
+ return True, get()
+ return False, None
else:
def poll(timeout): # noqa
@@ -1236,8 +1216,13 @@ class Pool(object):
return result
def terminate_job(self, pid, sig=None):
- self.signalled.add(pid)
- _kill(pid, sig or signal.SIGTERM)
+ try:
+ _kill(pid, sig or signal.SIGTERM)
+ except OSError, exc:
+ if get_errno(exc) != errno.ESRCH:
+ raise
+ else:
+ self.signalled.add(pid)
def map_async(self, func, iterable, chunksize=None,
callback=None, error_callback=None):
diff --git a/awx/lib/site-packages/billiard/process.py b/awx/lib/site-packages/billiard/process.py
index 79c850bd4b..b8418229db 100644
--- a/awx/lib/site-packages/billiard/process.py
+++ b/awx/lib/site-packages/billiard/process.py
@@ -19,6 +19,8 @@ import sys
import signal
import itertools
import binascii
+import logging
+import threading
from .compat import bytes
try:
@@ -45,9 +47,17 @@ def current_process():
def _cleanup():
# check for processes which have finished
- for p in list(_current_process._children):
- if p._popen.poll() is not None:
- _current_process._children.discard(p)
+ if _current_process is not None:
+ for p in list(_current_process._children):
+ if p._popen.poll() is not None:
+ _current_process._children.discard(p)
+
+
+def _maybe_flush(f):
+ try:
+ f.flush()
+ except (AttributeError, EnvironmentError, NotImplementedError):
+ pass
def active_children(_cleanup=_cleanup):
@@ -59,7 +69,9 @@ def active_children(_cleanup=_cleanup):
except TypeError:
# called after gc collect so _cleanup does not exist anymore
return []
- return list(_current_process._children)
+ if _current_process is not None:
+ return list(_current_process._children)
+ return []
class Process(object):
@@ -242,6 +254,18 @@ class Process(object):
pass
old_process = _current_process
_current_process = self
+
+ # Re-init logging system.
+ # Workaround for http://bugs.python.org/issue6721#msg140215
+ # Python logging module uses RLock() objects which are broken after
+ # fork. This can result in a deadlock (Celery Issue #496).
+ logger_names = logging.Logger.manager.loggerDict.keys()
+ logger_names.append(None) # for root logger
+ for name in logger_names:
+ for handler in logging.getLogger(name).handlers:
+ handler.createLock()
+ logging._lock = threading.RLock()
+
try:
util._finalizer_registry.clear()
util._run_after_forkers()
@@ -262,7 +286,7 @@ class Process(object):
exitcode = e.args[0]
else:
sys.stderr.write(str(e.args[0]) + '\n')
- sys.stderr.flush()
+ _maybe_flush(sys.stderr)
exitcode = 0 if isinstance(e.args[0], str) else 1
except:
exitcode = 1
@@ -273,8 +297,8 @@ class Process(object):
finally:
util.info('process %s exiting with exitcode %d',
self.pid, exitcode)
- sys.stdout.flush()
- sys.stderr.flush()
+ _maybe_flush(sys.stdout)
+ _maybe_flush(sys.stderr)
return exitcode
#
diff --git a/awx/lib/site-packages/billiard/queues.py b/awx/lib/site-packages/billiard/queues.py
index a44c2b47a7..2554a5ef07 100644
--- a/awx/lib/site-packages/billiard/queues.py
+++ b/awx/lib/site-packages/billiard/queues.py
@@ -334,6 +334,10 @@ class SimpleQueue(object):
def _make_methods(self):
recv = self._reader.recv
+ try:
+ recv_payload = self._reader.recv_payload
+ except AttributeError:
+ recv_payload = None # C extension not installed
rlock = self._rlock
def get():
@@ -341,6 +345,12 @@ class SimpleQueue(object):
return recv()
self.get = get
+ if recv_payload is not None:
+ def get_payload():
+ with rlock:
+ return recv_payload()
+ self.get_payload = get_payload
+
if self._wlock is None:
# writes to a message oriented win32 pipe are atomic
self.put = self._writer.send
diff --git a/awx/lib/site-packages/billiard/util.py b/awx/lib/site-packages/billiard/util.py
index f669512049..76c8431a6f 100644
--- a/awx/lib/site-packages/billiard/util.py
+++ b/awx/lib/site-packages/billiard/util.py
@@ -244,7 +244,9 @@ class Finalize(object):
return x + '>'
-def _run_finalizers(minpriority=None):
+def _run_finalizers(minpriority=None,
+ _finalizer_registry=_finalizer_registry,
+ sub_debug=sub_debug, error=error):
'''
Run all finalizers whose exit priority is not None and at least minpriority
@@ -280,7 +282,9 @@ def is_exiting():
return _exiting or _exiting is None
-def _exit_function():
+def _exit_function(info=info, debug=debug,
+ active_children=active_children,
+ _run_finalizers=_run_finalizers):
'''
Clean up on exit
'''
diff --git a/awx/lib/site-packages/celery/__compat__.py b/awx/lib/site-packages/celery/__compat__.py
index e09772346c..08700719cb 100644
--- a/awx/lib/site-packages/celery/__compat__.py
+++ b/awx/lib/site-packages/celery/__compat__.py
@@ -14,7 +14,12 @@ from __future__ import absolute_import
import operator
import sys
-from functools import reduce
+# import fails in python 2.5. fallback to reduce in stdlib
+try:
+ from functools import reduce
+except ImportError:
+ pass
+
from importlib import import_module
from types import ModuleType
diff --git a/awx/lib/site-packages/celery/__init__.py b/awx/lib/site-packages/celery/__init__.py
index 0aa88032cf..49a5ed9baf 100644
--- a/awx/lib/site-packages/celery/__init__.py
+++ b/awx/lib/site-packages/celery/__init__.py
@@ -8,7 +8,7 @@
from __future__ import absolute_import
SERIES = 'Chiastic Slide'
-VERSION = (3, 0, 19)
+VERSION = (3, 0, 22)
__version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:])
__author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org'
diff --git a/awx/lib/site-packages/celery/app/builtins.py b/awx/lib/site-packages/celery/app/builtins.py
index 0c3582f5ed..1938af21cc 100644
--- a/awx/lib/site-packages/celery/app/builtins.py
+++ b/awx/lib/site-packages/celery/app/builtins.py
@@ -307,8 +307,9 @@ def add_chord_task(app):
accept_magic_kwargs = False
ignore_result = False
- def run(self, header, body, partial_args=(), interval=1, countdown=1,
- max_retries=None, propagate=None, eager=False, **kwargs):
+ def run(self, header, body, partial_args=(), interval=None,
+ countdown=1, max_retries=None, propagate=None,
+ eager=False, **kwargs):
propagate = default_propagate if propagate is None else propagate
group_id = uuid()
AsyncResult = self.app.AsyncResult
diff --git a/awx/lib/site-packages/celery/app/task.py b/awx/lib/site-packages/celery/app/task.py
index d3c0fe3522..94282ecbf2 100644
--- a/awx/lib/site-packages/celery/app/task.py
+++ b/awx/lib/site-packages/celery/app/task.py
@@ -198,11 +198,11 @@ class Task(object):
serializer = None
#: Hard time limit.
- #: Defaults to the :setting:`CELERY_TASK_TIME_LIMIT` setting.
+ #: Defaults to the :setting:`CELERYD_TASK_TIME_LIMIT` setting.
time_limit = None
#: Soft time limit.
- #: Defaults to the :setting:`CELERY_TASK_SOFT_TIME_LIMIT` setting.
+ #: Defaults to the :setting:`CELERYD_TASK_SOFT_TIME_LIMIT` setting.
soft_time_limit = None
#: The result store backend used for this task.
@@ -459,7 +459,8 @@ class Task(object):
args = (self.__self__, ) + tuple(args)
if conf.CELERY_ALWAYS_EAGER:
- return self.apply(args, kwargs, task_id=task_id, **options)
+ return self.apply(args, kwargs, task_id=task_id,
+ link=link, link_error=link_error, **options)
options = dict(extract_exec_options(self), **options)
options = router.route(options, self.name, args, kwargs)
@@ -580,7 +581,8 @@ class Task(object):
raise ret
return ret
- def apply(self, args=None, kwargs=None, **options):
+ def apply(self, args=None, kwargs=None,
+ link=None, link_error=None, **options):
"""Execute this task locally, by blocking until the task returns.
:param args: positional arguments passed on to the task.
@@ -614,6 +616,8 @@ class Task(object):
'is_eager': True,
'logfile': options.get('logfile'),
'loglevel': options.get('loglevel', 0),
+ 'callbacks': maybe_list(link),
+ 'errbacks': maybe_list(link_error),
'delivery_info': {'is_eager': True}}
if self.accept_magic_kwargs:
default_kwargs = {'task_name': task.name,
diff --git a/awx/lib/site-packages/celery/apps/worker.py b/awx/lib/site-packages/celery/apps/worker.py
index c3e6d6da42..2d6d67c19a 100644
--- a/awx/lib/site-packages/celery/apps/worker.py
+++ b/awx/lib/site-packages/celery/apps/worker.py
@@ -251,7 +251,7 @@ class Worker(configurated):
'version': VERSION_BANNER,
'conninfo': self.app.connection().as_uri(),
'concurrency': concurrency,
- 'platform': _platform.platform(),
+ 'platform': safe_str(_platform.platform()),
'events': events,
'queues': app.amqp.queues.format(indent=0, indent_first=False),
}).splitlines()
diff --git a/awx/lib/site-packages/celery/backends/mongodb.py b/awx/lib/site-packages/celery/backends/mongodb.py
index 2027d66b3b..30bbf673f3 100644
--- a/awx/lib/site-packages/celery/backends/mongodb.py
+++ b/awx/lib/site-packages/celery/backends/mongodb.py
@@ -114,6 +114,8 @@ class MongoBackend(BaseDictBackend):
if self._connection is not None:
# MongoDB connection will be closed automatically when object
# goes out of scope
+ del(self.collection)
+ del(self.database)
self._connection = None
def _store_result(self, task_id, result, status, traceback=None):
@@ -124,7 +126,7 @@ class MongoBackend(BaseDictBackend):
'date_done': datetime.utcnow(),
'traceback': Binary(self.encode(traceback)),
'children': Binary(self.encode(self.current_task_children()))}
- self.collection.save(meta, safe=True)
+ self.collection.save(meta)
return result
@@ -151,7 +153,7 @@ class MongoBackend(BaseDictBackend):
meta = {'_id': group_id,
'result': Binary(self.encode(result)),
'date_done': datetime.utcnow()}
- self.collection.save(meta, safe=True)
+ self.collection.save(meta)
return result
@@ -183,7 +185,7 @@ class MongoBackend(BaseDictBackend):
# By using safe=True, this will wait until it receives a response from
# the server. Likewise, it will raise an OperationsError if the
# response was unable to be completed.
- self.collection.remove({'_id': task_id}, safe=True)
+ self.collection.remove({'_id': task_id})
def cleanup(self):
"""Delete expired metadata."""
diff --git a/awx/lib/site-packages/celery/bin/celery.py b/awx/lib/site-packages/celery/bin/celery.py
index d92a497ed6..c0dd4a8e1e 100644
--- a/awx/lib/site-packages/celery/bin/celery.py
+++ b/awx/lib/site-packages/celery/bin/celery.py
@@ -614,7 +614,7 @@ class control(_RemoteControl):
def call(self, method, *args, **options):
# XXX Python 2.5 doesn't support X(*args, reply=True, **kwargs)
return getattr(self.app.control, method)(
- *args, **dict(options, retry=True))
+ *args, **dict(options, reply=True))
def pool_grow(self, method, n=1, **kwargs):
"""[N=1]"""
@@ -866,7 +866,7 @@ class CeleryCommand(BaseCommand):
cls = self.commands.get(command) or self.commands['help']
try:
return cls(app=self.app).run_from_argv(self.prog_name, argv)
- except (TypeError, Error):
+ except Error:
return self.execute('help', argv)
def remove_options_at_beginning(self, argv, index=0):
diff --git a/awx/lib/site-packages/celery/canvas.py b/awx/lib/site-packages/celery/canvas.py
index d8dcfabcad..0eb360f6de 100644
--- a/awx/lib/site-packages/celery/canvas.py
+++ b/awx/lib/site-packages/celery/canvas.py
@@ -267,7 +267,8 @@ class Signature(dict):
class chain(Signature):
def __init__(self, *tasks, **options):
- tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
+ tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0])
+ else tasks)
Signature.__init__(
self, 'celery.chain', (), {'tasks': tasks}, **options
)
@@ -283,7 +284,7 @@ class chain(Signature):
tasks = d['kwargs']['tasks']
if d['args'] and tasks:
# partial args passed on to first task in chain (Issue #1057).
- tasks[0]['args'] = d['args'] + tasks[0]['args']
+ tasks[0]['args'] = tasks[0]._merge(d['args'])[0]
return chain(*d['kwargs']['tasks'], **kwdict(d['options']))
@property
@@ -392,7 +393,7 @@ class group(Signature):
if d['args'] and tasks:
# partial args passed on to all tasks in the group (Issue #1057).
for task in tasks:
- task['args'] = d['args'] + task['args']
+ task['args'] = task._merge(d['args'])[0]
return group(tasks, **kwdict(d['options']))
def __call__(self, *partial_args, **options):
diff --git a/awx/lib/site-packages/celery/concurrency/eventlet.py b/awx/lib/site-packages/celery/concurrency/eventlet.py
index fd97269f61..5f8d68750f 100644
--- a/awx/lib/site-packages/celery/concurrency/eventlet.py
+++ b/awx/lib/site-packages/celery/concurrency/eventlet.py
@@ -34,7 +34,8 @@ if not EVENTLET_NOPATCH and not PATCHED[0]:
import eventlet
import eventlet.debug
eventlet.monkey_patch()
- eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK)
+ if EVENTLET_DBLOCK:
+ eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK)
from time import time
diff --git a/awx/lib/site-packages/celery/contrib/migrate.py b/awx/lib/site-packages/celery/contrib/migrate.py
index 9e54979455..76fe1db2fa 100644
--- a/awx/lib/site-packages/celery/contrib/migrate.py
+++ b/awx/lib/site-packages/celery/contrib/migrate.py
@@ -236,6 +236,8 @@ def start_filter(app, conn, filter, limit=None, timeout=1.0,
consume_from=None, state=None, **kwargs):
state = state or State()
queues = prepare_queues(queues)
+ consume_from = [_maybe_queue(app, q)
+ for q in consume_from or queues.keys()]
if isinstance(tasks, basestring):
tasks = set(tasks.split(','))
if tasks is None:
diff --git a/awx/lib/site-packages/celery/datastructures.py b/awx/lib/site-packages/celery/datastructures.py
index f3e9c2e2b6..67f6c66b39 100644
--- a/awx/lib/site-packages/celery/datastructures.py
+++ b/awx/lib/site-packages/celery/datastructures.py
@@ -472,7 +472,10 @@ class LimitedSet(object):
if time.time() < item[0] + self.expires:
heappush(H, item)
break
- self._data.pop(item[1])
+ try:
+ self._data.pop(item[1])
+ except KeyError: # out of sync with heap
+ pass
i += 1
def update(self, other, heappush=heappush):
diff --git a/awx/lib/site-packages/celery/events/state.py b/awx/lib/site-packages/celery/events/state.py
index 0afbec5383..8c129c2a38 100644
--- a/awx/lib/site-packages/celery/events/state.py
+++ b/awx/lib/site-packages/celery/events/state.py
@@ -210,11 +210,20 @@ class State(object):
task_count = 0
def __init__(self, callback=None,
+ workers=None, tasks=None, taskheap=None,
max_workers_in_memory=5000, max_tasks_in_memory=10000):
- self.workers = LRUCache(limit=max_workers_in_memory)
- self.tasks = LRUCache(limit=max_tasks_in_memory)
self.event_callback = callback
+ self.workers = (LRUCache(max_workers_in_memory)
+ if workers is None else workers)
+ self.tasks = (LRUCache(max_tasks_in_memory)
+ if tasks is None else tasks)
+ self._taskheap = None # reserved for __reduce__ in 3.1
+ self.max_workers_in_memory = max_workers_in_memory
+ self.max_tasks_in_memory = max_tasks_in_memory
self._mutex = threading.Lock()
+ self.handlers = {'task': self.task_event,
+ 'worker': self.worker_event}
+ self._get_handler = self.handlers.__getitem__
def freeze_while(self, fun, *args, **kwargs):
clear_after = kwargs.pop('clear_after', False)
@@ -295,11 +304,14 @@ class State(object):
with self._mutex:
return self._dispatch_event(event)
- def _dispatch_event(self, event):
+ def _dispatch_event(self, event, kwdict=kwdict):
self.event_count += 1
event = kwdict(event)
group, _, subject = event['type'].partition('-')
- getattr(self, group + '_event')(subject, event)
+ try:
+ self._get_handler(group)(subject, event)
+ except KeyError:
+ pass
if self.event_callback:
self.event_callback(self, event)
@@ -356,14 +368,10 @@ class State(object):
return '<ClusterState: events=%s tasks=%s>' % (self.event_count,
self.task_count)
- def __getstate__(self):
- d = dict(vars(self))
- d.pop('_mutex')
- return d
-
- def __setstate__(self, state):
- self.__dict__ = state
- self._mutex = threading.Lock()
-
+ def __reduce__(self):
+ return self.__class__, (
+ self.event_callback, self.workers, self.tasks, None,
+ self.max_workers_in_memory, self.max_tasks_in_memory,
+ )
state = State()
diff --git a/awx/lib/site-packages/celery/loaders/__init__.py b/awx/lib/site-packages/celery/loaders/__init__.py
index 6f8aea72e9..1bd2baafcb 100644
--- a/awx/lib/site-packages/celery/loaders/__init__.py
+++ b/awx/lib/site-packages/celery/loaders/__init__.py
@@ -11,7 +11,7 @@ from __future__ import absolute_import
from celery._state import current_app
from celery.utils import deprecated
-from celery.utils.imports import symbol_by_name
+from celery.utils.imports import symbol_by_name, import_from_cwd
LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader',
'default': 'celery.loaders.default:Loader',
@@ -20,7 +20,7 @@ LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader',
def get_loader_cls(loader):
"""Get loader class by name/alias"""
- return symbol_by_name(loader, LOADER_ALIASES)
+ return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd)
@deprecated(deprecation='2.5', removal='4.0',
diff --git a/awx/lib/site-packages/celery/platforms.py b/awx/lib/site-packages/celery/platforms.py
index 7919cb10b8..3e6b0f5338 100644
--- a/awx/lib/site-packages/celery/platforms.py
+++ b/awx/lib/site-packages/celery/platforms.py
@@ -48,6 +48,12 @@ PIDFILE_MODE = ((os.R_OK | os.W_OK) << 6) | ((os.R_OK) << 3) | ((os.R_OK))
PIDLOCKED = """ERROR: Pidfile (%s) already exists.
Seems we're already running? (pid: %s)"""
+try:
+ from io import UnsupportedOperation
+ FILENO_ERRORS = (AttributeError, UnsupportedOperation)
+except ImportError: # Py2
+ FILENO_ERRORS = (AttributeError, ) # noqa
+
def pyimplementation():
"""Returns string identifying the current Python implementation."""
@@ -253,17 +259,21 @@ def _create_pidlock(pidfile):
def fileno(f):
- """Get object fileno, or :const:`None` if not defined."""
- if isinstance(f, int):
+ if isinstance(f, (int, long)):
return f
+ return f.fileno()
+
+
+def maybe_fileno(f):
+ """Get object fileno, or :const:`None` if not defined."""
try:
- return f.fileno()
- except AttributeError:
+ return fileno(f)
+ except FILENO_ERRORS:
pass
def close_open_fds(keep=None):
- keep = [fileno(f) for f in keep if fileno(f)] if keep else []
+ keep = [maybe_fileno(f) for f in keep if maybe_fileno(f)] if keep else []
for fd in reversed(range(get_fdmax(default=2048))):
if fd not in keep:
with ignore_errno(errno.EBADF):
@@ -299,7 +309,7 @@ class DaemonContext(object):
close_open_fds(self.stdfds)
for fd in self.stdfds:
- self.redirect_to_null(fileno(fd))
+ self.redirect_to_null(maybe_fileno(fd))
self._is_open = True
__enter__ = open
diff --git a/awx/lib/site-packages/celery/result.py b/awx/lib/site-packages/celery/result.py
index 1b6af3aec9..acf1ff6cbe 100644
--- a/awx/lib/site-packages/celery/result.py
+++ b/awx/lib/site-packages/celery/result.py
@@ -350,7 +350,7 @@ class ResultSet(ResultBase):
def failed(self):
"""Did any of the tasks fail?
- :returns: :const:`True` if any of the tasks failed.
+ :returns: :const:`True` if one of the tasks failed.
(i.e., raised an exception)
"""
@@ -359,7 +359,7 @@ class ResultSet(ResultBase):
def waiting(self):
"""Are any of the tasks incomplete?
- :returns: :const:`True` if any of the tasks is still
+ :returns: :const:`True` if one of the tasks are still
waiting for execution.
"""
@@ -368,7 +368,7 @@ class ResultSet(ResultBase):
def ready(self):
"""Did all of the tasks complete? (either by success of failure).
- :returns: :const:`True` if all of the tasks been
+ :returns: :const:`True` if all of the tasks has been
executed.
"""
@@ -435,7 +435,7 @@ class ResultSet(ResultBase):
time.sleep(interval)
elapsed += interval
if timeout and elapsed >= timeout:
- raise TimeoutError("The operation timed out")
+ raise TimeoutError('The operation timed out')
def get(self, timeout=None, propagate=True, interval=0.5):
"""See :meth:`join`
@@ -694,7 +694,7 @@ class EagerResult(AsyncResult):
self._state = states.REVOKED
def __repr__(self):
- return "<EagerResult: %s>" % self.id
+ return '<EagerResult: %s>' % self.id
@property
def result(self):
diff --git a/awx/lib/site-packages/celery/schedules.py b/awx/lib/site-packages/celery/schedules.py
index ca0e3aa314..cc7353d02c 100644
--- a/awx/lib/site-packages/celery/schedules.py
+++ b/awx/lib/site-packages/celery/schedules.py
@@ -379,7 +379,11 @@ class crontab(schedule):
flag = (datedata.dom == len(days_of_month) or
day_out_of_range(datedata.year,
months_of_year[datedata.moy],
- days_of_month[datedata.dom]))
+ days_of_month[datedata.dom]) or
+ (self.maybe_make_aware(datetime(datedata.year,
+ months_of_year[datedata.moy],
+ days_of_month[datedata.dom])) < last_run_at))
+
if flag:
datedata.dom = 0
datedata.moy += 1
@@ -449,10 +453,11 @@ class crontab(schedule):
self._orig_day_of_month,
self._orig_month_of_year), None)
- def remaining_estimate(self, last_run_at, tz=None):
+ def remaining_delta(self, last_run_at, tz=None):
"""Returns when the periodic task should run next as a timedelta."""
tz = tz or self.tz
last_run_at = self.maybe_make_aware(last_run_at)
+ now = self.maybe_make_aware(self.now())
dow_num = last_run_at.isoweekday() % 7 # Sunday is day 0, not day 7
execute_this_date = (last_run_at.month in self.month_of_year and
@@ -460,6 +465,9 @@ class crontab(schedule):
dow_num in self.day_of_week)
execute_this_hour = (execute_this_date and
+ last_run_at.day == now.day and
+ last_run_at.month == now.month and
+ last_run_at.year == now.year and
last_run_at.hour in self.hour and
last_run_at.minute < max(self.minute))
@@ -499,10 +507,11 @@ class crontab(schedule):
else:
delta = self._delta_to_next(last_run_at,
next_hour, next_minute)
+ return self.to_local(last_run_at), delta, self.to_local(now)
- now = self.maybe_make_aware(self.now())
- return remaining(self.to_local(last_run_at), delta,
- self.to_local(now))
+ def remaining_estimate(self, last_run_at):
+ """Returns when the periodic task should run next as a timedelta."""
+ return remaining(*self.remaining_delta(last_run_at))
def is_due(self, last_run_at):
"""Returns tuple of two items `(is_due, next_time_to_run)`,
diff --git a/awx/lib/site-packages/celery/task/trace.py b/awx/lib/site-packages/celery/task/trace.py
index ab1a4d3e08..6a2a3bef88 100644
--- a/awx/lib/site-packages/celery/task/trace.py
+++ b/awx/lib/site-packages/celery/task/trace.py
@@ -30,7 +30,10 @@ from celery.app import set_default_app
from celery.app.task import Task as BaseTask, Context
from celery.datastructures import ExceptionInfo
from celery.exceptions import Ignore, RetryTaskError
-from celery.utils.serialization import get_pickleable_exception
+from celery.utils.serialization import (
+ get_pickleable_exception,
+ get_pickleable_etype,
+)
from celery.utils.log import get_logger
_logger = get_logger(__name__)
@@ -128,7 +131,9 @@ class TraceInfo(object):
type_, _, tb = sys.exc_info()
try:
exc = self.retval
- einfo = ExceptionInfo((type_, get_pickleable_exception(exc), tb))
+ einfo = ExceptionInfo()
+ einfo.exception = get_pickleable_exception(einfo.exception)
+ einfo.type = get_pickleable_etype(einfo.type)
if store_errors:
task.backend.mark_as_failure(req.id, exc, einfo.traceback)
task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
diff --git a/awx/lib/site-packages/celery/tests/backends/test_base.py b/awx/lib/site-packages/celery/tests/backends/test_base.py
index cae67425c8..b04919c3b4 100644
--- a/awx/lib/site-packages/celery/tests/backends/test_base.py
+++ b/awx/lib/site-packages/celery/tests/backends/test_base.py
@@ -11,8 +11,7 @@ from celery import current_app
from celery.result import AsyncResult, GroupResult
from celery.utils import serialization
from celery.utils.serialization import subclass_exception
-from celery.utils.serialization import \
- find_nearest_pickleable_exception as fnpe
+from celery.utils.serialization import find_pickleable_exception as fnpe
from celery.utils.serialization import UnpickleableExceptionWrapper
from celery.utils.serialization import get_pickleable_exception as gpe
diff --git a/awx/lib/site-packages/celery/tests/backends/test_mongodb.py b/awx/lib/site-packages/celery/tests/backends/test_mongodb.py
index 8980176581..3c15ab40db 100644
--- a/awx/lib/site-packages/celery/tests/backends/test_mongodb.py
+++ b/awx/lib/site-packages/celery/tests/backends/test_mongodb.py
@@ -287,7 +287,7 @@ class test_MongoBackend(AppCase):
mock_database.__getitem__.assert_called_once_with(
MONGODB_COLLECTION)
mock_collection.remove.assert_called_once_with(
- {'_id': sentinel.task_id}, safe=True)
+ {'_id': sentinel.task_id})
@patch('celery.backends.mongodb.MongoBackend._get_database')
def test_cleanup(self, mock_get_database):
diff --git a/awx/lib/site-packages/celery/tests/tasks/test_canvas.py b/awx/lib/site-packages/celery/tests/tasks/test_canvas.py
index 8011a80de8..a56033ddd2 100644
--- a/awx/lib/site-packages/celery/tests/tasks/test_canvas.py
+++ b/awx/lib/site-packages/celery/tests/tasks/test_canvas.py
@@ -134,6 +134,11 @@ class test_chain(Case):
self.assertEqual(res.parent.parent.get(), 8)
self.assertIsNone(res.parent.parent.parent)
+ def test_accepts_generator_argument(self):
+ x = chain(add.s(i) for i in range(10))
+ self.assertTrue(x.tasks[0].type, add)
+ self.assertTrue(x.type)
+
class test_group(Case):
diff --git a/awx/lib/site-packages/celery/tests/tasks/test_tasks.py b/awx/lib/site-packages/celery/tests/tasks/test_tasks.py
index dd53e50faf..65e6e77862 100644
--- a/awx/lib/site-packages/celery/tests/tasks/test_tasks.py
+++ b/awx/lib/site-packages/celery/tests/tasks/test_tasks.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import
from __future__ import with_statement
+import time
+
from datetime import datetime, timedelta
from functools import wraps
from mock import patch
@@ -616,6 +618,14 @@ def monthly():
pass
+@periodic_task(run_every=crontab(hour=22,
+ day_of_week='*',
+ month_of_year='2',
+ day_of_month='26,27,28'))
+def monthly_moy():
+ pass
+
+
@periodic_task(run_every=crontab(hour=7, minute=30,
day_of_week='thursday',
day_of_month='8-14',
@@ -1212,6 +1222,40 @@ class test_crontab_is_due(Case):
self.assertFalse(due)
self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60)
+ @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
+ def test_monthly_moy_execution_is_due(self):
+ due, remaining = monthly_moy.run_every.is_due(
+ datetime(2013, 7, 4, 10, 0))
+ self.assertTrue(due)
+ self.assertEqual(remaining, 60.)
+
+ @patch_crontab_nowfun(monthly_moy, datetime(2013, 6, 28, 14, 30))
+ def test_monthly_moy_execution_is_not_due(self):
+ due, remaining = monthly_moy.run_every.is_due(
+ datetime(2013, 6, 28, 22, 14))
+ self.assertFalse(due)
+ attempt = (
+ time.mktime(datetime(2014, 2, 26, 22, 0).timetuple()) -
+ time.mktime(datetime(2013, 6, 28, 14, 30).timetuple()) -
+ 60 * 60
+ )
+ self.assertEqual(remaining, attempt)
+
+ @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
+ def test_monthly_moy_execution_is_due2(self):
+ due, remaining = monthly_moy.run_every.is_due(
+ datetime(2013, 2, 28, 10, 0))
+ self.assertTrue(due)
+ self.assertEqual(remaining, 60.)
+
+ @patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 21, 0))
+ def test_monthly_moy_execution_is_not_due2(self):
+ due, remaining = monthly_moy.run_every.is_due(
+ datetime(2013, 6, 28, 22, 14))
+ self.assertFalse(due)
+ attempt = 60 * 60
+ self.assertEqual(remaining, attempt)
+
@patch_crontab_nowfun(yearly, datetime(2010, 3, 11, 7, 30))
def test_yearly_execution_is_due(self):
due, remaining = yearly.run_every.is_due(
diff --git a/awx/lib/site-packages/celery/tests/utilities/test_encoding.py b/awx/lib/site-packages/celery/tests/utilities/test_encoding.py
index a7bd28a411..d9885b857f 100644
--- a/awx/lib/site-packages/celery/tests/utilities/test_encoding.py
+++ b/awx/lib/site-packages/celery/tests/utilities/test_encoding.py
@@ -1,9 +1,5 @@
from __future__ import absolute_import
-import sys
-
-from nose import SkipTest
-
from celery.utils import encoding
from celery.tests.utils import Case
@@ -15,17 +11,6 @@ class test_encoding(Case):
self.assertTrue(encoding.safe_str('foo'))
self.assertTrue(encoding.safe_str(u'foo'))
- def test_safe_str_UnicodeDecodeError(self):
- if sys.version_info >= (3, 0):
- raise SkipTest('py3k: not relevant')
-
- class foo(unicode):
-
- def encode(self, *args, **kwargs):
- raise UnicodeDecodeError('foo')
-
- self.assertIn('<Unrepresentable', encoding.safe_str(foo()))
-
def test_safe_repr(self):
self.assertTrue(encoding.safe_repr(object()))
diff --git a/awx/lib/site-packages/celery/tests/utilities/test_term.py b/awx/lib/site-packages/celery/tests/utilities/test_term.py
index ce1285701d..5e6ca50631 100644
--- a/awx/lib/site-packages/celery/tests/utilities/test_term.py
+++ b/awx/lib/site-packages/celery/tests/utilities/test_term.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
+from kombu.utils.encoding import str_t
+
from celery.utils import term
from celery.utils.term import colored, fg
@@ -38,7 +40,7 @@ class test_colored(Case):
self.assertTrue(str(colored().iwhite('f')))
self.assertTrue(str(colored().reset('f')))
- self.assertTrue(str(colored().green(u'∂bar')))
+ self.assertTrue(str_t(colored().green(u'∂bar')))
self.assertTrue(
colored().red(u'éefoo') + colored().green(u'∂bar'))
diff --git a/awx/lib/site-packages/celery/tests/worker/test_worker.py b/awx/lib/site-packages/celery/tests/worker/test_worker.py
index 006488dbf5..044f033073 100644
--- a/awx/lib/site-packages/celery/tests/worker/test_worker.py
+++ b/awx/lib/site-packages/celery/tests/worker/test_worker.py
@@ -967,10 +967,12 @@ class test_WorkController(AppCase):
except ImportError:
raise SkipTest('multiprocessing not supported')
self.assertIsInstance(worker.ready_queue, AsyncTaskBucket)
- self.assertFalse(worker.mediator)
- self.assertNotEqual(worker.ready_queue.put, worker.process_task)
+ # XXX disabled until 3.1
+ #self.assertFalse(worker.mediator)
+ #self.assertNotEqual(worker.ready_queue.put, worker.process_task)
def test_disable_rate_limits_processes(self):
+ raise SkipTest('disabled until v3.1')
try:
worker = self.create_worker(disable_rate_limits=True,
use_eventloop=False,
@@ -1058,6 +1060,7 @@ class test_WorkController(AppCase):
self.assertTrue(w.disable_rate_limits)
def test_Queues_pool_no_sem(self):
+ raise SkipTest('disabled until v3.1')
w = Mock()
w.pool_cls.uses_semaphore = False
Queues(w).create(w)
@@ -1086,6 +1089,7 @@ class test_WorkController(AppCase):
w.hub.on_init = []
w.pool_cls = Mock()
P = w.pool_cls.return_value = Mock()
+ P._cache = {}
P.timers = {Mock(): 30}
w.use_eventloop = True
w.consumer.restart_count = -1
@@ -1105,23 +1109,13 @@ class test_WorkController(AppCase):
cbs['on_process_down'](w)
hub.remove.assert_called_with(w.sentinel)
+ w.pool._tref_for_id = {}
+
result = Mock()
- tref = result._tref
cbs['on_timeout_cancel'](result)
- tref.cancel.assert_called_with()
cbs['on_timeout_cancel'](result) # no more tref
- cbs['on_timeout_set'](result, 10, 20)
- tsoft, callback = hub.timer.apply_after.call_args[0]
- callback()
-
- cbs['on_timeout_set'](result, 10, None)
- tsoft, callback = hub.timer.apply_after.call_args[0]
- callback()
- cbs['on_timeout_set'](result, None, 10)
- cbs['on_timeout_set'](result, None, None)
-
with self.assertRaises(WorkerLostError):
P.did_start_ok.return_value = False
w.consumer.restart_count = 0
diff --git a/awx/lib/site-packages/celery/utils/imports.py b/awx/lib/site-packages/celery/utils/imports.py
index 0ea1d77523..e46462663e 100644
--- a/awx/lib/site-packages/celery/utils/imports.py
+++ b/awx/lib/site-packages/celery/utils/imports.py
@@ -28,15 +28,18 @@ class NotAPackage(Exception):
if sys.version_info >= (3, 3): # pragma: no cover
def qualname(obj):
- return obj.__qualname__
-
+ if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
+ obj = obj.__class__
+ q = getattr(obj, '__qualname__', None)
+ if '.' not in q:
+ q = '.'.join((obj.__module__, q))
+ return q
else:
def qualname(obj): # noqa
if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
return qualname(obj.__class__)
-
- return '%s.%s' % (obj.__module__, obj.__name__)
+ return '.'.join((obj.__module__, obj.__name__))
def instantiate(name, *args, **kwargs):
diff --git a/awx/lib/site-packages/celery/utils/serialization.py b/awx/lib/site-packages/celery/utils/serialization.py
index 0a7cdf25c7..0cce7e971a 100644
--- a/awx/lib/site-packages/celery/utils/serialization.py
+++ b/awx/lib/site-packages/celery/utils/serialization.py
@@ -50,7 +50,8 @@ else:
return type(name, (parent,), {'__module__': module})
-def find_nearest_pickleable_exception(exc):
+def find_pickleable_exception(exc, loads=pickle.loads,
+ dumps=pickle.dumps):
"""With an exception instance, iterate over its super classes (by mro)
and find the first super exception that is pickleable. It does
not go below :exc:`Exception` (i.e. it skips :exc:`Exception`,
@@ -65,7 +66,19 @@ def find_nearest_pickleable_exception(exc):
:rtype :exc:`Exception`:
"""
- cls = exc.__class__
+ exc_args = getattr(exc, 'args', [])
+ for supercls in itermro(exc.__class__, unwanted_base_classes):
+ try:
+ superexc = supercls(*exc_args)
+ loads(dumps(superexc))
+ except:
+ pass
+ else:
+ return superexc
+find_nearest_pickleable_exception = find_pickleable_exception # XXX compat
+
+
+def itermro(cls, stop):
getmro_ = getattr(cls, 'mro', None)
# old-style classes doesn't have mro()
@@ -77,18 +90,11 @@ def find_nearest_pickleable_exception(exc):
getmro_ = lambda: inspect.getmro(cls)
for supercls in getmro_():
- if supercls in unwanted_base_classes:
+ if supercls in stop:
# only BaseException and object, from here on down,
# we don't care about these.
return
- try:
- exc_args = getattr(exc, 'args', [])
- superexc = supercls(*exc_args)
- pickle.loads(pickle.dumps(superexc))
- except:
- pass
- else:
- return superexc
+ yield supercls
def create_exception_cls(name, module, parent=None):
@@ -165,12 +171,19 @@ def get_pickleable_exception(exc):
pass
else:
return exc
- nearest = find_nearest_pickleable_exception(exc)
+ nearest = find_pickleable_exception(exc)
if nearest:
return nearest
return UnpickleableExceptionWrapper.from_exception(exc)
+def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps):
+ try:
+ loads(dumps(cls))
+ except:
+ return Exception
+
+
def get_pickled_exception(exc):
"""Get original exception from exception pickled using
:meth:`get_pickleable_exception`."""
diff --git a/awx/lib/site-packages/celery/utils/threads.py b/awx/lib/site-packages/celery/utils/threads.py
index 8f88aabf84..9826363491 100644
--- a/awx/lib/site-packages/celery/utils/threads.py
+++ b/awx/lib/site-packages/celery/utils/threads.py
@@ -18,29 +18,26 @@ from celery.utils.compat import THREAD_TIMEOUT_MAX
USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS')
PY3 = sys.version_info[0] == 3
+NEW_EVENT = (sys.version_info[0] == 3) and (sys.version_info[1] >= 3)
_Thread = threading.Thread
-_Event = threading.Event if PY3 else threading._Event
+_Event = threading.Event if NEW_EVENT else threading._Event
active_count = (getattr(threading, 'active_count', None) or
threading.activeCount)
-class Event(_Event):
+if sys.version_info < (2, 6):
- if not hasattr(_Event, 'is_set'): # pragma: no cover
+ class Event(_Event): # pragma: no cover
is_set = _Event.isSet
-
-class Thread(_Thread):
-
- if not hasattr(_Thread, 'is_alive'): # pragma: no cover
+ class Thread(_Thread): # pragma: no cover
is_alive = _Thread.isAlive
-
- if not hasattr(_Thread, 'daemon'): # pragma: no cover
daemon = property(_Thread.isDaemon, _Thread.setDaemon)
-
- if not hasattr(_Thread, 'name'): # pragma: no cover
name = property(_Thread.getName, _Thread.setName)
+else:
+ Event = _Event
+ Thread = _Thread
class bgThread(Thread):
diff --git a/awx/lib/site-packages/celery/worker/__init__.py b/awx/lib/site-packages/celery/worker/__init__.py
index 10faa394de..04f7bcd2e4 100644
--- a/awx/lib/site-packages/celery/worker/__init__.py
+++ b/awx/lib/site-packages/celery/worker/__init__.py
@@ -19,6 +19,7 @@ import time
import traceback
from functools import partial
+from weakref import WeakValueDictionary
from billiard.exceptions import WorkerLostError
from billiard.util import Finalize
@@ -26,6 +27,7 @@ from kombu.syn import detect_environment
from celery import concurrency as _concurrency
from celery import platforms
+from celery import signals
from celery.app import app_or_default
from celery.app.abstract import configurated, from_config
from celery.exceptions import SystemTerminate, TaskRevokedError
@@ -105,6 +107,7 @@ class Pool(bootsteps.StartStopComponent):
add_reader = hub.add_reader
remove = hub.remove
now = time.time
+ cache = pool._pool._cache
# did_start_ok will verify that pool processes were able to start,
# but this will only work the first time we start, as
@@ -120,25 +123,58 @@ class Pool(bootsteps.StartStopComponent):
for handler, interval in pool.timers.iteritems():
hub.timer.apply_interval(interval * 1000.0, handler)
- def on_timeout_set(R, soft, hard):
+ trefs = pool._tref_for_id = WeakValueDictionary()
+
+ def _discard_tref(job):
+ try:
+ tref = trefs.pop(job)
+ tref.cancel()
+ del(tref)
+ except (KeyError, AttributeError):
+ pass # out of scope
+
+ def _on_hard_timeout(job):
+ try:
+ result = cache[job]
+ except KeyError:
+ pass # job ready
+ else:
+ on_hard_timeout(result)
+ finally:
+ # remove tref
+ _discard_tref(job)
+
+ def _on_soft_timeout(job, soft, hard, hub):
+ if hard:
+ trefs[job] = apply_at(
+ now() + (hard - soft),
+ _on_hard_timeout, (job, ),
+ )
+ try:
+ result = cache[job]
+ except KeyError:
+ pass # job ready
+ else:
+ on_soft_timeout(result)
+ finally:
+ if not hard:
+ # remove tref
+ _discard_tref(job)
- def _on_soft_timeout():
- if hard:
- R._tref = apply_at(now() + (hard - soft),
- on_hard_timeout, (R, ))
- on_soft_timeout(R)
+ def on_timeout_set(R, soft, hard):
if soft:
- R._tref = apply_after(soft * 1000.0, _on_soft_timeout)
+ trefs[R._job] = apply_after(
+ soft * 1000.0,
+ _on_soft_timeout, (R._job, soft, hard, hub),
+ )
elif hard:
- R._tref = apply_after(hard * 1000.0,
- on_hard_timeout, (R, ))
+ trefs[R._job] = apply_after(
+ hard * 1000.0,
+ _on_hard_timeout, (R._job, )
+ )
- def on_timeout_cancel(result):
- try:
- result._tref.cancel()
- delattr(result, '_tref')
- except AttributeError:
- pass
+ def on_timeout_cancel(R):
+ _discard_tref(R._job)
pool.init_callbacks(
on_process_up=lambda w: add_reader(w.sentinel, maintain_pool),
@@ -208,19 +244,18 @@ class Queues(bootsteps.Component):
def create(self, w):
BucketType = TaskBucket
- w.start_mediator = not w.disable_rate_limits
+ w.start_mediator = w.pool_cls.requires_mediator
if not w.pool_cls.rlimit_safe:
- w.start_mediator = False
BucketType = AsyncTaskBucket
process_task = w.process_task
if w.use_eventloop:
- w.start_mediator = False
BucketType = AsyncTaskBucket
if w.pool_putlocks and w.pool_cls.uses_semaphore:
process_task = w.process_task_sem
- if w.disable_rate_limits:
+ if w.disable_rate_limits or not w.start_mediator:
w.ready_queue = FastQueue()
- w.ready_queue.put = process_task
+ if not w.start_mediator:
+ w.ready_queue.put = process_task
else:
w.ready_queue = BucketType(
task_registry=w.app.tasks, callback=process_task, worker=w,
@@ -327,7 +362,10 @@ class WorkController(configurated):
self.loglevel = loglevel or self.loglevel
self.hostname = hostname or socket.gethostname()
self.ready_callback = ready_callback
- self._finalize = Finalize(self, self.stop, exitpriority=1)
+ self._finalize = [
+ Finalize(self, self.stop, exitpriority=1),
+ Finalize(self, self._send_worker_shutdown, exitpriority=10),
+ ]
self.pidfile = pidfile
self.pidlock = None
# this connection is not established, only used for params
@@ -350,6 +388,9 @@ class WorkController(configurated):
self.components = []
self.namespace = Namespace(app=self.app).apply(self, **kwargs)
+ def _send_worker_shutdown(self):
+ signals.worker_shutdown.send(sender=self)
+
def start(self):
"""Starts the workers main loop."""
self._state = self.RUN
diff --git a/awx/lib/site-packages/celery/worker/buckets.py b/awx/lib/site-packages/celery/worker/buckets.py
index e975328d49..2d7ccc3081 100644
--- a/awx/lib/site-packages/celery/worker/buckets.py
+++ b/awx/lib/site-packages/celery/worker/buckets.py
@@ -39,6 +39,12 @@ class AsyncTaskBucket(object):
self.worker = worker
self.buckets = {}
self.refresh()
+ self._queue = Queue()
+ self._quick_put = self._queue.put
+ self.get = self._queue.get
+
+ def get(self, *args, **kwargs):
+ return self._queue.get(*args, **kwargs)
def cont(self, request, bucket, tokens):
if not bucket.can_consume(tokens):
@@ -47,7 +53,7 @@ class AsyncTaskBucket(object):
hold * 1000.0, self.cont, (request, bucket, tokens),
)
else:
- self.callback(request)
+ self._quick_put(request)
def put(self, request):
name = request.name
@@ -56,7 +62,7 @@ class AsyncTaskBucket(object):
except KeyError:
bucket = self.add_bucket_for_type(name)
if not bucket:
- return self.callback(request)
+ return self._quick_put(request)
return self.cont(request, bucket, 1)
def add_task_type(self, name):
diff --git a/awx/lib/site-packages/celery/worker/consumer.py b/awx/lib/site-packages/celery/worker/consumer.py
index aed01f3e02..4d811ffed2 100644
--- a/awx/lib/site-packages/celery/worker/consumer.py
+++ b/awx/lib/site-packages/celery/worker/consumer.py
@@ -80,6 +80,8 @@ import threading
from time import sleep
from Queue import Empty
+from billiard.common import restart_state
+from billiard.exceptions import RestartFreqExceeded
from kombu.syn import _detect_environment
from kombu.utils.encoding import safe_repr, safe_str, bytes_t
from kombu.utils.eventio import READ, WRITE, ERR
@@ -100,6 +102,13 @@ from .bootsteps import StartStopComponent
from .control import Panel
from .heartbeat import Heart
+try:
+ buffer_t = buffer
+except NameError: # pragma: no cover
+
+ class buffer_t(object): # noqa
+ pass
+
RUN = 0x1
CLOSE = 0x2
@@ -171,7 +180,7 @@ def debug(msg, *args, **kwargs):
def dump_body(m, body):
- if isinstance(body, buffer):
+ if isinstance(body, buffer_t):
body = bytes_t(body)
return "%s (%sb)" % (text.truncate(safe_repr(body), 1024), len(m.body))
@@ -348,6 +357,7 @@ class Consumer(object):
conninfo = self.app.connection()
self.connection_errors = conninfo.connection_errors
self.channel_errors = conninfo.channel_errors
+ self._restart_state = restart_state(maxR=5, maxT=1)
self._does_info = logger.isEnabledFor(logging.INFO)
self.strategies = {}
@@ -391,6 +401,11 @@ class Consumer(object):
self.restart_count += 1
self.maybe_shutdown()
try:
+ self._restart_state.step()
+ except RestartFreqExceeded as exc:
+ crit('Frequent restarts detected: %r', exc, exc_info=1)
+ sleep(1)
+ try:
self.reset_connection()
self.consume_messages()
except self.connection_errors + self.channel_errors:
@@ -736,6 +751,7 @@ class Consumer(object):
# to the current channel.
self.ready_queue.clear()
self.timer.clear()
+ state.reserved_requests.clear()
# Re-establish the broker connection and setup the task consumer.
self.connection = self._open_connection()
diff --git a/awx/lib/site-packages/celery/worker/mediator.py b/awx/lib/site-packages/celery/worker/mediator.py
index b467b71c74..0e10392797 100644
--- a/awx/lib/site-packages/celery/worker/mediator.py
+++ b/awx/lib/site-packages/celery/worker/mediator.py
@@ -36,7 +36,7 @@ class WorkerComponent(StartStopComponent):
w.mediator = None
def include_if(self, w):
- return w.start_mediator and not w.use_eventloop
+ return w.start_mediator
def create(self, w):
m = w.mediator = self.instantiate(w.mediator_cls, w.ready_queue,
diff --git a/awx/lib/site-packages/django_extensions/__init__.py b/awx/lib/site-packages/django_extensions/__init__.py
index 3efe475a76..3426dc3739 100644
--- a/awx/lib/site-packages/django_extensions/__init__.py
+++ b/awx/lib/site-packages/django_extensions/__init__.py
@@ -1,5 +1,5 @@
-VERSION = (1, 1, 1)
+VERSION = (1, 2, 0)
# Dynamically calculate the version based on VERSION tuple
if len(VERSION) > 2 and VERSION[2] is not None:
diff --git a/awx/lib/site-packages/django_extensions/admin/__init__.py b/awx/lib/site-packages/django_extensions/admin/__init__.py
index 4df241460b..7564ea65e3 100644
--- a/awx/lib/site-packages/django_extensions/admin/__init__.py
+++ b/awx/lib/site-packages/django_extensions/admin/__init__.py
@@ -9,6 +9,7 @@
# (Michal Salaban)
#
+import six
import operator
from six.moves import reduce
from django.http import HttpResponse, HttpResponseNotFound
@@ -108,8 +109,7 @@ class ForeignKeyAutocompleteAdmin(ModelAdmin):
other_qs.dup_select_related(queryset)
other_qs = other_qs.filter(reduce(operator.or_, or_queries))
queryset = queryset & other_qs
- data = ''.join([u'%s|%s\n' % (
- to_string_function(f), f.pk) for f in queryset])
+ data = ''.join([six.u('%s|%s\n' % (to_string_function(f), f.pk)) for f in queryset])
elif object_pk:
try:
obj = queryset.get(pk=object_pk)
@@ -139,7 +139,7 @@ class ForeignKeyAutocompleteAdmin(ModelAdmin):
model_name = db_field.rel.to._meta.object_name
help_text = self.get_help_text(db_field.name, model_name)
if kwargs.get('help_text'):
- help_text = u'%s %s' % (kwargs['help_text'], help_text)
+ help_text = six.u('%s %s' % (kwargs['help_text'], help_text))
kwargs['widget'] = ForeignKeySearchInput(db_field.rel, self.related_search_fields[db_field.name])
kwargs['help_text'] = help_text
return super(ForeignKeyAutocompleteAdmin, self).formfield_for_dbfield(db_field, **kwargs)
diff --git a/awx/lib/site-packages/django_extensions/admin/widgets.py b/awx/lib/site-packages/django_extensions/admin/widgets.py
index feefd406bc..6a2a4af959 100644
--- a/awx/lib/site-packages/django_extensions/admin/widgets.py
+++ b/awx/lib/site-packages/django_extensions/admin/widgets.py
@@ -1,3 +1,4 @@
+import six
import django
from django import forms
from django.conf import settings
@@ -67,7 +68,7 @@ class ForeignKeySearchInput(ForeignKeyRawIdWidget):
if value:
label = self.label_for_value(value)
else:
- label = u''
+ label = six.u('')
try:
admin_media_prefix = settings.ADMIN_MEDIA_PREFIX
@@ -92,4 +93,4 @@ class ForeignKeySearchInput(ForeignKeyRawIdWidget):
'django_extensions/widgets/foreignkey_searchinput.html',
), context))
output.reverse()
- return mark_safe(u''.join(output))
+ return mark_safe(six.u(''.join(output)))
diff --git a/awx/lib/site-packages/django_extensions/db/fields/__init__.py b/awx/lib/site-packages/django_extensions/db/fields/__init__.py
index 8a17c48064..337ddd2c27 100644
--- a/awx/lib/site-packages/django_extensions/db/fields/__init__.py
+++ b/awx/lib/site-packages/django_extensions/db/fields/__init__.py
@@ -3,13 +3,13 @@ Django Extensions additional model fields
"""
import re
import six
-
try:
import uuid
- assert uuid
+ HAS_UUID = True
except ImportError:
- from django_extensions.utils import uuid
+ HAS_UUID = False
+from django.core.exceptions import ImproperlyConfigured
from django.template.defaultfilters import slugify
from django.db.models import DateTimeField, CharField, SlugField
@@ -56,7 +56,7 @@ class AutoSlugField(SlugField):
raise ValueError("missing 'populate_from' argument")
else:
self._populate_from = populate_from
- self.separator = kwargs.pop('separator', u'-')
+ self.separator = kwargs.pop('separator', six.u('-'))
self.overwrite = kwargs.pop('overwrite', False)
self.allow_duplicates = kwargs.pop('allow_duplicates', False)
super(AutoSlugField, self).__init__(*args, **kwargs)
@@ -221,13 +221,15 @@ class UUIDVersionError(Exception):
class UUIDField(CharField):
""" UUIDField
- By default uses UUID version 4 (generate from host ID, sequence number and current time)
+ By default uses UUID version 4 (randomly generated UUID).
- The field support all uuid versions which are natively supported by the uuid python module.
+ The field support all uuid versions which are natively supported by the uuid python module, except version 2.
For more information see: http://docs.python.org/lib/module-uuid.html
"""
- def __init__(self, verbose_name=None, name=None, auto=True, version=1, node=None, clock_seq=None, namespace=None, **kwargs):
+ def __init__(self, verbose_name=None, name=None, auto=True, version=4, node=None, clock_seq=None, namespace=None, **kwargs):
+ if not HAS_UUID:
+ raise ImproperlyConfigured("'uuid' module is required for UUIDField. (Do you have Python 2.5 or higher installed ?)")
kwargs.setdefault('max_length', 36)
if auto:
self.empty_strings_allowed = False
@@ -244,17 +246,6 @@ class UUIDField(CharField):
def get_internal_type(self):
return CharField.__name__
- def contribute_to_class(self, cls, name):
- if self.primary_key:
- assert not cls._meta.has_auto_field, "A model can't have more than one AutoField: %s %s %s; have %s" % (
- self, cls, name, cls._meta.auto_field
- )
- super(UUIDField, self).contribute_to_class(cls, name)
- cls._meta.has_auto_field = True
- cls._meta.auto_field = self
- else:
- super(UUIDField, self).contribute_to_class(cls, name)
-
def create_uuid(self):
if not self.version or self.version == 4:
return uuid.uuid4()
@@ -277,7 +268,7 @@ class UUIDField(CharField):
return value
else:
if self.auto and not value:
- value = six.u(self.create_uuid())
+ value = force_unicode(self.create_uuid())
setattr(model_instance, self.attname, value)
return value
diff --git a/awx/lib/site-packages/django_extensions/db/fields/encrypted.py b/awx/lib/site-packages/django_extensions/db/fields/encrypted.py
index 5d45b6912a..f897598a96 100644
--- a/awx/lib/site-packages/django_extensions/db/fields/encrypted.py
+++ b/awx/lib/site-packages/django_extensions/db/fields/encrypted.py
@@ -21,8 +21,11 @@ class BaseEncryptedField(models.Field):
def __init__(self, *args, **kwargs):
if not hasattr(settings, 'ENCRYPTED_FIELD_KEYS_DIR'):
- raise ImproperlyConfigured('You must set the ENCRYPTED_FIELD_KEYS_DIR setting to your Keyczar keys directory.')
- self.crypt = keyczar.Crypter.Read(settings.ENCRYPTED_FIELD_KEYS_DIR)
+ raise ImproperlyConfigured('You must set the ENCRYPTED_FIELD_KEYS_DIR '
+ 'setting to your Keyczar keys directory.')
+
+ crypt_class = self.get_crypt_class()
+ self.crypt = crypt_class.Read(settings.ENCRYPTED_FIELD_KEYS_DIR)
# Encrypted size is larger than unencrypted
self.unencrypted_length = max_length = kwargs.get('max_length', None)
@@ -34,6 +37,32 @@ class BaseEncryptedField(models.Field):
super(BaseEncryptedField, self).__init__(*args, **kwargs)
+ def get_crypt_class(self):
+ """
+ Get the Keyczar class to use.
+
+ The class can be customized with the ENCRYPTED_FIELD_MODE setting. By default,
+ this setting is DECRYPT_AND_ENCRYPT. Set this to ENCRYPT to disable decryption.
+ This is necessary if you are only providing public keys to Keyczar.
+
+ Returns:
+ keyczar.Encrypter if ENCRYPTED_FIELD_MODE is ENCRYPT.
+ keyczar.Crypter if ENCRYPTED_FIELD_MODE is DECRYPT_AND_ENCRYPT.
+
+ Override this method to customize the type of Keyczar class returned.
+ """
+
+ crypt_type = getattr(settings, 'ENCRYPTED_FIELD_MODE', 'DECRYPT_AND_ENCRYPT')
+ if crypt_type == 'ENCRYPT':
+ crypt_class_name = 'Encrypter'
+ elif crypt_type == 'DECRYPT_AND_ENCRYPT':
+ crypt_class_name = 'Crypter'
+ else:
+ raise ImproperlyConfigured(
+ 'ENCRYPTED_FIELD_MODE must be either DECRYPT_AND_ENCRYPT '
+ 'or ENCRYPT, not %s.' % crypt_type)
+ return getattr(keyczar, crypt_class_name)
+
def to_python(self, value):
if isinstance(self.crypt.primary_key, keyczar.keys.RsaPublicKey):
retval = value
@@ -64,9 +93,8 @@ class BaseEncryptedField(models.Field):
return value
-class EncryptedTextField(BaseEncryptedField):
- __metaclass__ = models.SubfieldBase
-
+class EncryptedTextField(six.with_metaclass(models.SubfieldBase,
+ BaseEncryptedField)):
def get_internal_type(self):
return 'TextField'
@@ -85,9 +113,8 @@ class EncryptedTextField(BaseEncryptedField):
return (field_class, args, kwargs)
-class EncryptedCharField(BaseEncryptedField):
- __metaclass__ = models.SubfieldBase
-
+class EncryptedCharField(six.with_metaclass(models.SubfieldBase,
+ BaseEncryptedField)):
def __init__(self, *args, **kwargs):
super(EncryptedCharField, self).__init__(*args, **kwargs)
@@ -107,4 +134,3 @@ class EncryptedCharField(BaseEncryptedField):
args, kwargs = introspector(self)
# That's our definition!
return (field_class, args, kwargs)
-
diff --git a/awx/lib/site-packages/django_extensions/db/fields/json.py b/awx/lib/site-packages/django_extensions/db/fields/json.py
index e1b1b51602..51d5b1dd52 100644
--- a/awx/lib/site-packages/django_extensions/db/fields/json.py
+++ b/awx/lib/site-packages/django_extensions/db/fields/json.py
@@ -58,16 +58,13 @@ class JSONList(list):
return dumps(self)
-class JSONField(models.TextField):
+class JSONField(six.with_metaclass(models.SubfieldBase, models.TextField)):
"""JSONField is a generic textfield that neatly serializes/unserializes
JSON objects seamlessly. Main thingy must be a dict object."""
- # Used so to_python() is called
- __metaclass__ = models.SubfieldBase
-
def __init__(self, *args, **kwargs):
- default = kwargs.get('default')
- if not default:
+ default = kwargs.get('default', None)
+ if default is None:
kwargs['default'] = '{}'
elif isinstance(default, (list, dict)):
kwargs['default'] = dumps(default)
diff --git a/awx/lib/site-packages/django_extensions/future_1_5.py b/awx/lib/site-packages/django_extensions/future_1_5.py
new file mode 100644
index 0000000000..d7144ca383
--- /dev/null
+++ b/awx/lib/site-packages/django_extensions/future_1_5.py
@@ -0,0 +1,16 @@
+"""
+A forwards compatibility module.
+
+Implements some features of Django 1.5 related to the 'Custom User Model' feature
+when the application is run with a lower version of Django.
+"""
+from __future__ import unicode_literals
+
+from django.contrib.auth.models import User
+
+User.USERNAME_FIELD = "username"
+User.get_username = lambda self: self.username
+
+
+def get_user_model():
+ return User
diff --git a/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py b/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py
index 299a0201fe..5caf40efa3 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/dumpscript.py
@@ -71,9 +71,9 @@ def orm_item_locator(orm_obj):
for key in clean_dict:
v = clean_dict[key]
if v is not None and not isinstance(v, (six.string_types, six.integer_types, float, datetime.datetime)):
- clean_dict[key] = u"%s" % v
+ clean_dict[key] = six.u("%s" % v)
- output = """ locate_object(%s, "%s", %s, "%s", %s, %s ) """ % (
+ output = """ importer.locate_object(%s, "%s", %s, "%s", %s, %s ) """ % (
original_class, original_pk_name,
the_class, pk_name, pk_value, clean_dict
)
@@ -264,7 +264,7 @@ class InstanceCode(Code):
# Print the save command for our new object
# e.g. model_name_35.save()
if code_lines:
- code_lines.append("%s = save_or_locate(%s)\n" % (self.variable_name, self.variable_name))
+ code_lines.append("%s = importer.save_or_locate(%s)\n" % (self.variable_name, self.variable_name))
code_lines += self.get_many_to_many_lines(force=force)
@@ -499,7 +499,6 @@ class Script(Code):
code.insert(2, "")
for key, value in self.context["__extra_imports"].items():
code.insert(2, " from %s import %s" % (value, key))
- code.insert(2 + len(self.context["__extra_imports"]), self.locate_object_function)
return code
@@ -513,11 +512,17 @@ class Script(Code):
# This file has been automatically generated.
# Instead of changing it, create a file called import_helper.py
-# which this script has hooks to.
+# and put there a class called ImportHelper(object) in it.
#
-# On that file, don't forget to add the necessary Django imports
-# and take a look at how locate_object() and save_or_locate()
-# are implemented here and expected to behave.
+# This class will be specially casted so that instead of extending object,
+# it will actually extend the class BasicImportHelper()
+#
+# That means you just have to overload the methods you want to
+# change, leaving the other ones inteact.
+#
+# Something that you might want to do is use transactions, for example.
+#
+# Also, don't forget to add the necessary Django imports.
#
# This file was generated with the following command:
# %s
@@ -530,24 +535,31 @@ class Script(Code):
# you must make sure ./some_folder/__init__.py exists
# and run ./manage.py runscript some_folder.some_script
+from django.db import transaction
-IMPORT_HELPER_AVAILABLE = False
-try:
- import import_helper
- IMPORT_HELPER_AVAILABLE = True
-except ImportError:
- pass
+class BasicImportHelper(object):
-import datetime
-from decimal import Decimal
-from django.contrib.contenttypes.models import ContentType
+ def pre_import(self):
+ pass
-def run():
+ # You probably want to uncomment on of these two lines
+ # @transaction.atomic # Django 1.6
+ # @transaction.commit_on_success # Django <1.6
+ def run_import(self, import_data):
+ import_data()
-""" % " ".join(sys.argv)
+ def post_import(self):
+ pass
+
+ def locate_similar(self, current_object, search_data):
+ #you will probably want to call this method from save_or_locate()
+ #example:
+ #new_obj = self.locate_similar(the_obj, {"national_id": the_obj.national_id } )
- locate_object_function = """
- def locate_object(original_class, original_pk_name, the_class, pk_name, pk_value, obj_content):
+ the_obj = current_object.__class__.objects.get(**search_data)
+ return the_obj
+
+ def locate_object(self, original_class, original_pk_name, the_class, pk_name, pk_value, obj_content):
#You may change this function to do specific lookup for specific objects
#
#original_class class of the django orm's object that needs to be located
@@ -571,22 +583,55 @@ def run():
#if the_class == StaffGroup:
# pk_value=8
-
- if IMPORT_HELPER_AVAILABLE and hasattr(import_helper, "locate_object"):
- return import_helper.locate_object(original_class, original_pk_name, the_class, pk_name, pk_value, obj_content)
-
search_data = { pk_name: pk_value }
- the_obj =the_class.objects.get(**search_data)
+ the_obj = the_class.objects.get(**search_data)
+ #print(the_obj)
return the_obj
- def save_or_locate(the_obj):
- if IMPORT_HELPER_AVAILABLE and hasattr(import_helper, "save_or_locate"):
- the_obj = import_helper.save_or_locate(the_obj)
- else:
+
+ def save_or_locate(self, the_obj):
+ #change this if you want to locate the object in the database
+ try:
the_obj.save()
+ except:
+ print("---------------")
+ print("Error saving the following object:")
+ print(the_obj.__class__)
+ print(" ")
+ print(the_obj.__dict__)
+ print(" ")
+ print(the_obj)
+ print(" ")
+ print("---------------")
+
+ raise
return the_obj
-"""
+
+importer = None
+try:
+ import import_helper
+ #we need this so ImportHelper can extend BasicImportHelper, although import_helper.py
+ #has no knowlodge of this class
+ importer = type("DynamicImportHelper", (import_helper.ImportHelper, BasicImportHelper ) , {} )()
+except ImportError as e:
+ if str(e) == "No module named import_helper":
+ importer = BasicImportHelper()
+ else:
+ raise
+
+import datetime
+from decimal import Decimal
+from django.contrib.contenttypes.models import ContentType
+
+def run():
+ importer.pre_import()
+ importer.run_import(import_data)
+ importer.post_import()
+
+def import_data():
+
+""" % " ".join(sys.argv)
# HELPER FUNCTIONS
diff --git a/awx/lib/site-packages/django_extensions/management/commands/export_emails.py b/awx/lib/site-packages/django_extensions/management/commands/export_emails.py
index 7766a1b6e2..05a8689b05 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/export_emails.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/export_emails.py
@@ -1,8 +1,13 @@
from django.core.management.base import BaseCommand, CommandError
-from django.contrib.auth.models import User, Group
+try:
+ from django.contrib.auth import get_user_model # Django 1.5
+except ImportError:
+ from django_extensions.future_1_5 import get_user_model
+from django.contrib.auth.models import Group
from optparse import make_option
from sys import stdout
from csv import writer
+import six
FORMATS = [
'address',
@@ -15,7 +20,7 @@ FORMATS = [
def full_name(first_name, last_name, username, **extra):
- name = u" ".join(n for n in [first_name, last_name] if n)
+ name = six.u(" ").join(n for n in [first_name, last_name] if n)
if not name:
return username
return name
@@ -42,7 +47,7 @@ class Command(BaseCommand):
raise CommandError("extra arguments supplied")
group = options['group']
if group and not Group.objects.filter(name=group).count() == 1:
- names = u"', '".join(g['name'] for g in Group.objects.values('name')).encode('utf-8')
+ names = six.u("', '").join(g['name'] for g in Group.objects.values('name')).encode('utf-8')
if names:
names = "'" + names + "'."
raise CommandError("Unknown group '" + group + "'. Valid group names are: " + names)
@@ -51,6 +56,7 @@ class Command(BaseCommand):
else:
outfile = stdout
+ User = get_user_model()
qs = User.objects.all().order_by('last_name', 'first_name', 'username', 'email')
if group:
qs = qs.filter(group__name=group).distinct()
@@ -61,15 +67,15 @@ class Command(BaseCommand):
"""simple single entry per line in the format of:
"full name" <my@address.com>;
"""
- out.write(u"\n".join(u'"%s" <%s>;' % (full_name(**ent), ent['email'])
- for ent in qs).encode(self.encoding))
+ out.write(six.u("\n").join(six.u('"%s" <%s>;' % (full_name(**ent), ent['email']))
+ for ent in qs).encode(self.encoding))
out.write("\n")
def emails(self, qs, out):
"""simpler single entry with email only in the format of:
my@address.com,
"""
- out.write(u",\n".join(u'%s' % (ent['email']) for ent in qs).encode(self.encoding))
+ out.write(six.u(",\n").join(six.u('%s' % (ent['email'])) for ent in qs).encode(self.encoding))
out.write("\n")
def google(self, qs, out):
diff --git a/awx/lib/site-packages/django_extensions/management/commands/graph_models.py b/awx/lib/site-packages/django_extensions/management/commands/graph_models.py
index 4ff03bb3aa..59403e248c 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/graph_models.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/graph_models.py
@@ -56,7 +56,7 @@ class Command(BaseCommand):
vizdata = ' '.join(dotdata.split("\n")).strip().encode('utf-8')
version = pygraphviz.__version__.rstrip("-svn")
try:
- if [int(v) for v in version.split('.')] < (0, 36):
+ if tuple(int(v) for v in version.split('.')) < (0, 36):
# HACK around old/broken AGraph before version 0.36 (ubuntu ships with this old version)
import tempfile
tmpfile = tempfile.NamedTemporaryFile()
diff --git a/awx/lib/site-packages/django_extensions/management/commands/passwd.py b/awx/lib/site-packages/django_extensions/management/commands/passwd.py
index 1bd58ab524..6f084b143d 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/passwd.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/passwd.py
@@ -1,5 +1,8 @@
from django.core.management.base import BaseCommand, CommandError
-from django.contrib.auth.models import User
+try:
+ from django.contrib.auth import get_user_model # Django 1.5
+except ImportError:
+ from django_extensions.future_1_5 import get_user_model
import getpass
@@ -17,6 +20,7 @@ class Command(BaseCommand):
else:
username = getpass.getuser()
+ User = get_user_model()
try:
u = User.objects.get(username=username)
except User.DoesNotExist:
diff --git a/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py b/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py
index 099ef86050..0cc7b3e27d 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/pipchecker.py
@@ -8,6 +8,12 @@ import urlparse
import xmlrpclib
from distutils.version import LooseVersion
+try:
+ import requests
+except ImportError:
+ print("""The requests library is not installed. To continue:
+ pip install requests""")
+
from optparse import make_option
from django.core.management.base import NoArgsCommand
@@ -166,13 +172,25 @@ class Command(NoArgsCommand):
}
if self.github_api_token:
headers["Authorization"] = "token {0}".format(self.github_api_token)
- user, repo = urlparse.urlparse(req_url).path.split("#")[0].strip("/").rstrip("/").split("/")
+ try:
+ user, repo = urlparse.urlparse(req_url).path.split("#")[0].strip("/").rstrip("/").split("/")
+ except (ValueError, IndexError) as e:
+ print("\nFailed to parse %r: %s\n" % (req_url, e))
+ continue
+
+ try:
+ #test_auth = self._urlopen_as_json("https://api.github.com/django/", headers=headers)
+ test_auth = requests.get("https://api.github.com/django/", headers=headers).json()
+ except urllib2.HTTPError as e:
+ print("\n%s\n" % str(e))
+ return
- test_auth = self._urlopen_as_json("https://api.github.com/django/", headers=headers)
if "message" in test_auth and test_auth["message"] == "Bad credentials":
- sys.exit("\nGithub API: Bad credentials. Aborting!\n")
+ print("\nGithub API: Bad credentials. Aborting!\n")
+ return
elif "message" in test_auth and test_auth["message"].startswith("API Rate Limit Exceeded"):
- sys.exit("\nGithub API: Rate Limit Exceeded. Aborting!\n")
+ print("\nGithub API: Rate Limit Exceeded. Aborting!\n")
+ return
if ".git" in repo:
repo_name, frozen_commit_full = repo.split(".git")
@@ -186,11 +204,14 @@ class Command(NoArgsCommand):
if frozen_commit_sha:
branch_url = "https://api.github.com/repos/{0}/{1}/branches".format(user, repo_name)
- branch_data = self._urlopen_as_json(branch_url, headers=headers)
-
- frozen_commit_url = "https://api.github.com/repos/{0}/{1}/commits/{2}" \
- .format(user, repo_name, frozen_commit_sha)
- frozen_commit_data = self._urlopen_as_json(frozen_commit_url, headers=headers)
+ #branch_data = self._urlopen_as_json(branch_url, headers=headers)
+ branch_data = requests.get(branch_url, headers=headers).json()
+
+ frozen_commit_url = "https://api.github.com/repos/{0}/{1}/commits/{2}".format(
+ user, repo_name, frozen_commit_sha
+ )
+ #frozen_commit_data = self._urlopen_as_json(frozen_commit_url, headers=headers)
+ frozen_commit_data = requests.get(frozen_commit_url, headers=headers).json()
if "message" in frozen_commit_data and frozen_commit_data["message"] == "Not Found":
msg = "{0} not found in {1}. Repo may be private.".format(frozen_commit_sha[:10], name)
diff --git a/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py b/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py
index 695c050bd8..000453aa81 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/print_user_for_session.py
@@ -1,5 +1,8 @@
from django.core.management.base import BaseCommand, CommandError
-from django.contrib.auth.models import User
+try:
+ from django.contrib.auth import get_user_model # Django 1.5
+except ImportError:
+ from django_extensions.future_1_5 import get_user_model
from django.contrib.sessions.models import Session
import re
@@ -38,6 +41,7 @@ class Command(BaseCommand):
print('No user associated with session')
return
print("User id: %s" % uid)
+ User = get_user_model()
try:
user = User.objects.get(pk=uid)
except User.DoesNotExist:
diff --git a/awx/lib/site-packages/django_extensions/management/commands/reset_db.py b/awx/lib/site-packages/django_extensions/management/commands/reset_db.py
index 71d59818f6..e3f3fad3b6 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/reset_db.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/reset_db.py
@@ -92,10 +92,10 @@ Type 'yes' to continue, or 'no' to cancel: """ % (settings.DATABASE_NAME,))
if password is None:
password = settings.DATABASE_PASSWORD
- if engine == 'sqlite3':
+ if engine in ('sqlite3', 'spatialite'):
import os
try:
- logging.info("Unlinking sqlite3 database")
+ logging.info("Unlinking %s database" % engine)
os.unlink(settings.DATABASE_NAME)
except OSError:
pass
diff --git a/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py b/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py
index 47cd87b917..7798a284e6 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/runprofileserver.py
@@ -116,6 +116,8 @@ class Command(BaseCommand):
help='Specifies the directory from which to serve admin media.'),
make_option('--prof-path', dest='prof_path', default='/tmp',
help='Specifies the directory which to save profile information in.'),
+ make_option('--prof-file', dest='prof_file', default='{path}.{duration:06d}ms.{time}',
+ help='Set filename format, default if "{path}.{duration:06d}ms.{time}".'),
make_option('--nomedia', action='store_true', dest='no_media', default=False,
help='Do not profile MEDIA_URL and ADMIN_MEDIA_URL'),
make_option('--use-cprofile', action='store_true', dest='use_cprofile', default=False,
@@ -186,6 +188,11 @@ class Command(BaseCommand):
raise SystemExit("Kcachegrind compatible output format required cProfile from Python 2.5")
prof_path = options.get('prof_path', '/tmp')
+ prof_file = options.get('prof_file', '{path}.{duration:06d}ms.{time}')
+ if not prof_file.format(path='1', duration=2, time=3):
+ prof_file = '{path}.{duration:06d}ms.{time}'
+ print("Filename format is wrong. Default format used: '{path}.{duration:06d}ms.{time}'.")
+
def get_exclude_paths():
exclude_paths = []
media_url = getattr(settings, 'MEDIA_URL', None)
@@ -225,8 +232,8 @@ class Command(BaseCommand):
kg.output(open(profname, 'w'))
elif USE_CPROFILE:
prof.dump_stats(profname)
- profname2 = "%s.%06dms.%d.prof" % (path_name, elapms, time.time())
- profname2 = os.path.join(prof_path, profname2)
+ profname2 = prof_file.format(path=path_name, duration=int(elapms), time=int(time.time()))
+ profname2 = os.path.join(prof_path, "%s.prof" % profname2)
if not USE_CPROFILE:
prof.close()
os.rename(profname, profname2)
@@ -278,4 +285,3 @@ class Command(BaseCommand):
autoreload.main(inner_run)
else:
inner_run()
-
diff --git a/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py b/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py
index 8899a6b687..678e8eb91b 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/runserver_plus.py
@@ -3,15 +3,31 @@ from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import setup_logger, RedirectHandler
from optparse import make_option
import os
+import re
+import socket
import sys
import time
try:
- from django.contrib.staticfiles.handlers import StaticFilesHandler
- USE_STATICFILES = 'django.contrib.staticfiles' in settings.INSTALLED_APPS
+ if 'django.contrib.staticfiles' in settings.INSTALLED_APPS:
+ from django.contrib.staticfiles.handlers import StaticFilesHandler
+ USE_STATICFILES = True
+ elif 'staticfiles' in settings.INSTALLED_APPS:
+ from staticfiles.handlers import StaticFilesHandler # noqa
+ USE_STATICFILES = True
+ else:
+ USE_STATICFILES = False
except ImportError:
USE_STATICFILES = False
+naiveip_re = re.compile(r"""^(?:
+(?P<addr>
+ (?P<ipv4>\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address
+ (?P<ipv6>\[[a-fA-F0-9:]+\]) | # IPv6 address
+ (?P<fqdn>[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN
+):)?(?P<port>\d+)$""", re.X)
+DEFAULT_PORT = "8000"
+
import logging
logger = logging.getLogger(__name__)
@@ -20,6 +36,8 @@ from django_extensions.management.technical_response import null_technical_500_r
class Command(BaseCommand):
option_list = BaseCommand.option_list + (
+ make_option('--ipv6', '-6', action='store_true', dest='use_ipv6', default=False,
+ help='Tells Django to use a IPv6 address.'),
make_option('--noreload', action='store_false', dest='use_reloader', default=True,
help='Tells Django to NOT use the auto-reloader.'),
make_option('--browser', action='store_true', dest='open_browser',
@@ -103,33 +121,51 @@ class Command(BaseCommand):
from django.views import debug
debug.technical_500_response = null_technical_500_response
- if args:
- raise CommandError('Usage is runserver %s' % self.args)
+ self.use_ipv6 = options.get('use_ipv6')
+ if self.use_ipv6 and not socket.has_ipv6:
+ raise CommandError('Your Python does not support IPv6.')
+ self._raw_ipv6 = False
if not addrport:
- addr = ''
- port = '8000'
- else:
try:
- addr, port = addrport.split(':')
- except ValueError:
- addr, port = '', addrport
- if not addr:
- addr = '127.0.0.1'
-
- if not port.isdigit():
- raise CommandError("%r is not a valid port number." % port)
+ addrport = settings.RUNSERVERPLUS_SERVER_ADDRESS_PORT
+ except AttributeError:
+ pass
+ if not addrport:
+ self.addr = ''
+ self.port = DEFAULT_PORT
+ else:
+ m = re.match(naiveip_re, addrport)
+ if m is None:
+ raise CommandError('"%s" is not a valid port number '
+ 'or address:port pair.' % addrport)
+ self.addr, _ipv4, _ipv6, _fqdn, self.port = m.groups()
+ if not self.port.isdigit():
+ raise CommandError("%r is not a valid port number." %
+ self.port)
+ if self.addr:
+ if _ipv6:
+ self.addr = self.addr[1:-1]
+ self.use_ipv6 = True
+ self._raw_ipv6 = True
+ elif self.use_ipv6 and not _fqdn:
+ raise CommandError('"%s" is not a valid IPv6 address.'
+ % self.addr)
+ if not self.addr:
+ self.addr = '::1' if self.use_ipv6 else '127.0.0.1'
threaded = options.get('threaded', False)
use_reloader = options.get('use_reloader', True)
open_browser = options.get('open_browser', False)
cert_path = options.get("cert_path")
quit_command = (sys.platform == 'win32') and 'CTRL-BREAK' or 'CONTROL-C'
+ bind_url = "http://%s:%s/" % (
+ self.addr if not self._raw_ipv6 else '[%s]' % self.addr, self.port)
def inner_run():
print("Validating models...")
self.validate(display_num_errors=True)
print("\nDjango version %s, using settings %r" % (django.get_version(), settings.SETTINGS_MODULE))
- print("Development server is running at http://%s:%s/" % (addr, port))
+ print("Development server is running at %s" % (bind_url,))
print("Using the Werkzeug debugger (http://werkzeug.pocoo.org/)")
print("Quit the server with %s." % quit_command)
path = options.get('admin_media_path', '')
@@ -149,8 +185,7 @@ class Command(BaseCommand):
handler = StaticFilesHandler(handler)
if open_browser:
import webbrowser
- url = "http://%s:%s/" % (addr, port)
- webbrowser.open(url)
+ webbrowser.open(bind_url)
if cert_path:
"""
OpenSSL is needed for SSL support.
@@ -189,8 +224,8 @@ class Command(BaseCommand):
else:
ssl_context = None
run_simple(
- addr,
- int(port),
+ self.addr,
+ int(self.port),
DebuggedApplication(handler, True),
use_reloader=use_reloader,
use_debugger=True,
diff --git a/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py b/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py
index df1e387a69..effeb476e2 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/set_fake_emails.py
@@ -38,7 +38,11 @@ class Command(NoArgsCommand):
if not settings.DEBUG:
raise CommandError('Only available in debug mode')
- from django.contrib.auth.models import User, Group
+ try:
+ from django.contrib.auth import get_user_model # Django 1.5
+ except ImportError:
+ from django_extensions.future_1_5 import get_user_model
+ from django.contrib.auth.models import Group
email = options.get('default_email', DEFAULT_FAKE_EMAIL)
include_regexp = options.get('include_regexp', None)
exclude_regexp = options.get('exclude_regexp', None)
@@ -47,6 +51,7 @@ class Command(NoArgsCommand):
no_admin = options.get('no_admin', False)
no_staff = options.get('no_staff', False)
+ User = get_user_model()
users = User.objects.all()
if no_admin:
users = users.exclude(is_superuser=True)
diff --git a/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py b/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py
index 3ce2437548..e502fbfdb4 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/set_fake_passwords.py
@@ -28,7 +28,11 @@ class Command(NoArgsCommand):
if not settings.DEBUG:
raise CommandError('Only available in debug mode')
- from django.contrib.auth.models import User
+ try:
+ from django.contrib.auth import get_user_model # Django 1.5
+ except ImportError:
+ from django_extensions.future_1_5 import get_user_model
+
if options.get('prompt_passwd', False):
from getpass import getpass
passwd = getpass('Password: ')
@@ -37,6 +41,7 @@ class Command(NoArgsCommand):
else:
passwd = options.get('default_passwd', DEFAULT_FAKE_PASSWORD)
+ User = get_user_model()
user = User()
user.set_password(passwd)
count = User.objects.all().update(password=user.password)
diff --git a/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py b/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py
index 19527ecdbe..8eacf127a6 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/sqlcreate.py
@@ -1,5 +1,6 @@
from optparse import make_option
import sys
+import socket
import django
from django.core.management.base import CommandError, BaseCommand
@@ -57,6 +58,7 @@ The envisioned use case is something like this:
dbuser = settings.DATABASE_USER
dbpass = settings.DATABASE_PASSWORD
dbhost = settings.DATABASE_HOST
+ dbclient = socket.gethostname()
# django settings file tells you that localhost should be specified by leaving
# the DATABASE_HOST blank
@@ -69,7 +71,7 @@ The envisioned use case is something like this:
""")
print("CREATE DATABASE %s CHARACTER SET utf8 COLLATE utf8_bin;" % dbname)
print("GRANT ALL PRIVILEGES ON %s.* to '%s'@'%s' identified by '%s';" % (
- dbname, dbuser, dbhost, dbpass
+ dbname, dbuser, dbclient, dbpass
))
elif engine == 'postgresql_psycopg2':
if options.get('drop'):
diff --git a/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py b/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py
index b2049a12c7..d89d218eef 100644
--- a/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py
+++ b/awx/lib/site-packages/django_extensions/management/commands/sqldiff.py
@@ -228,7 +228,7 @@ class SQLDiff(object):
def strip_parameters(self, field_type):
if field_type and field_type != 'double precision':
- return field_type.split(" ")[0].split("(")[0]
+ return field_type.split(" ")[0].split("(")[0].lower()
return field_type
def find_unique_missing_in_db(self, meta, table_indexes, table_name):
@@ -289,14 +289,14 @@ class SQLDiff(object):
continue
description = db_fields[field.name]
- model_type = self.strip_parameters(self.get_field_model_type(field))
- db_type = self.strip_parameters(self.get_field_db_type(description, field))
+ model_type = self.get_field_model_type(field)
+ db_type = self.get_field_db_type(description, field)
# use callback function if defined
if func:
model_type, db_type = func(field, description, model_type, db_type)
- if not model_type == db_type:
+ if not self.strip_parameters(db_type) == self.strip_parameters(model_type):
self.add_difference('field-type-differ', table_name, field.name, model_type, db_type)
def find_field_parameter_differ(self, meta, table_description, table_name, func=None):
diff --git a/awx/lib/site-packages/django_extensions/management/modelviz.py b/awx/lib/site-packages/django_extensions/management/modelviz.py
index 83ac484798..1d18d16972 100644
--- a/awx/lib/site-packages/django_extensions/management/modelviz.py
+++ b/awx/lib/site-packages/django_extensions/management/modelviz.py
@@ -210,9 +210,9 @@ def generate_dot(app_labels, **kwargs):
if skip_field(field):
continue
if isinstance(field, OneToOneField):
- add_relation(field, '[arrowhead=none, arrowtail=none]')
+ add_relation(field, '[arrowhead=none, arrowtail=none, dir=both]')
elif isinstance(field, ForeignKey):
- add_relation(field, '[arrowhead=none, arrowtail=dot]')
+ add_relation(field, '[arrowhead=none, arrowtail=dot, dir=both]')
for field in appmodel._meta.local_many_to_many:
if skip_field(field):
@@ -240,7 +240,7 @@ def generate_dot(app_labels, **kwargs):
'type': "inheritance",
'name': "inheritance",
'label': l,
- 'arrows': '[arrowhead=empty, arrowtail=none]',
+ 'arrows': '[arrowhead=empty, arrowtail=none, dir=both]',
'needs_node': True
}
# TODO: seems as if abstract models aren't part of models.getModels, which is why they are printed by this without any attributes.
diff --git a/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py b/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py
index 4ffcc3ec17..d36ddb8405 100644
--- a/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py
+++ b/awx/lib/site-packages/django_extensions/mongodb/fields/__init__.py
@@ -69,7 +69,7 @@ class AutoSlugField(SlugField):
raise ValueError("missing 'populate_from' argument")
else:
self._populate_from = populate_from
- self.separator = kwargs.pop('separator', u'-')
+ self.separator = kwargs.pop('separator', six.u('-'))
self.overwrite = kwargs.pop('overwrite', False)
super(AutoSlugField, self).__init__(*args, **kwargs)
diff --git a/awx/lib/site-packages/django_extensions/templatetags/widont.py b/awx/lib/site-packages/django_extensions/templatetags/widont.py
index 687e5114ee..d42833f941 100644
--- a/awx/lib/site-packages/django_extensions/templatetags/widont.py
+++ b/awx/lib/site-packages/django_extensions/templatetags/widont.py
@@ -1,6 +1,7 @@
from django.template import Library
from django.utils.encoding import force_unicode
import re
+import six
register = Library()
re_widont = re.compile(r'\s+(\S+\s*)$')
@@ -24,7 +25,7 @@ def widont(value, count=1):
NoEffect
"""
def replace(matchobj):
- return u'&nbsp;%s' % matchobj.group(1)
+ return six.u('&nbsp;%s' % matchobj.group(1))
for i in range(count):
value = re_widont.sub(replace, force_unicode(value))
return value
@@ -48,7 +49,7 @@ def widont_html(value):
leading&nbsp;text <p>test me&nbsp;out</p> trailing&nbsp;text
"""
def replace(matchobj):
- return u'%s&nbsp;%s%s' % matchobj.groups()
+ return six.u('%s&nbsp;%s%s' % matchobj.groups())
return re_widont_html.sub(replace, force_unicode(value))
register.filter(widont)
diff --git a/awx/lib/site-packages/django_extensions/tests/json_field.py b/awx/lib/site-packages/django_extensions/tests/json_field.py
index a73d1f6706..e9aed0ffc0 100644
--- a/awx/lib/site-packages/django_extensions/tests/json_field.py
+++ b/awx/lib/site-packages/django_extensions/tests/json_field.py
@@ -25,9 +25,13 @@ class JsonFieldTest(unittest.TestCase):
def testCharFieldCreate(self):
j = TestModel.objects.create(a=6, j_field=dict(foo='bar'))
- self.assertEquals(j.a, 6)
+ self.assertEqual(j.a, 6)
+
+ def testDefault(self):
+ j = TestModel.objects.create(a=1)
+ self.assertEqual(j.j_field, {})
def testEmptyList(self):
j = TestModel.objects.create(a=6, j_field=[])
self.assertTrue(isinstance(j.j_field, list))
- self.assertEquals(j.j_field, [])
+ self.assertEqual(j.j_field, [])
diff --git a/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py b/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py
index b25dae0537..dd4b2190b2 100644
--- a/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py
+++ b/awx/lib/site-packages/django_extensions/tests/test_dumpscript.py
@@ -52,7 +52,7 @@ class DumpScriptTests(TestCase):
tmp_out = StringIO()
call_command('dumpscript', 'tests', stdout=tmp_out)
self.assertTrue('Mike' in tmp_out.getvalue()) # script should go to tmp_out
- self.assertEquals(0, len(sys.stdout.getvalue())) # there should not be any output to sys.stdout
+ self.assertEqual(0, len(sys.stdout.getvalue())) # there should not be any output to sys.stdout
tmp_out.close()
#----------------------------------------------------------------------
@@ -65,7 +65,7 @@ class DumpScriptTests(TestCase):
call_command('dumpscript', 'tests', stderr=tmp_err)
self.assertTrue('Fred' in sys.stdout.getvalue()) # script should still go to stdout
self.assertTrue('Name' in tmp_err.getvalue()) # error output should go to tmp_err
- self.assertEquals(0, len(sys.stderr.getvalue())) # there should not be any output to sys.stderr
+ self.assertEqual(0, len(sys.stderr.getvalue())) # there should not be any output to sys.stderr
tmp_err.close()
#----------------------------------------------------------------------
diff --git a/awx/lib/site-packages/django_extensions/tests/utils.py b/awx/lib/site-packages/django_extensions/tests/utils.py
index c91989f9f4..23935b66e0 100644
--- a/awx/lib/site-packages/django_extensions/tests/utils.py
+++ b/awx/lib/site-packages/django_extensions/tests/utils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import sys
+import six
from django.test import TestCase
from django.utils.unittest import skipIf
@@ -14,21 +15,21 @@ except ImportError:
class TruncateLetterTests(TestCase):
def test_truncate_more_than_text_length(self):
- self.assertEquals(u"hello tests", truncate_letters("hello tests", 100))
+ self.assertEqual(six.u("hello tests"), truncate_letters("hello tests", 100))
def test_truncate_text(self):
- self.assertEquals(u"hello...", truncate_letters("hello tests", 5))
+ self.assertEqual(six.u("hello..."), truncate_letters("hello tests", 5))
def test_truncate_with_range(self):
for i in range(10, -1, -1):
self.assertEqual(
- u'hello tests'[:i] + '...',
+ six.u('hello tests'[:i]) + '...',
truncate_letters("hello tests", i)
)
def test_with_non_ascii_characters(self):
- self.assertEquals(
- u'\u5ce0 (\u3068\u3046\u3052 t\u014dg...',
+ self.assertEqual(
+ six.u('\u5ce0 (\u3068\u3046\u3052 t\u014dg...'),
truncate_letters("峠 (とうげ tōge - mountain pass)", 10)
)
@@ -37,7 +38,7 @@ class UUIDTests(TestCase):
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_uuid3(self):
# make a UUID using an MD5 hash of a namespace UUID and a name
- self.assertEquals(
+ self.assertEqual(
uuid.UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e'),
uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org')
)
@@ -45,7 +46,7 @@ class UUIDTests(TestCase):
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_uuid5(self):
# make a UUID using a SHA-1 hash of a namespace UUID and a name
- self.assertEquals(
+ self.assertEqual(
uuid.UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d'),
uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org')
)
@@ -55,21 +56,21 @@ class UUIDTests(TestCase):
# make a UUID from a string of hex digits (braces and hyphens ignored)
x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
# convert a UUID to a string of hex digits in standard form
- self.assertEquals('00010203-0405-0607-0809-0a0b0c0d0e0f', str(x))
+ self.assertEqual('00010203-0405-0607-0809-0a0b0c0d0e0f', str(x))
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_uuid_bytes(self):
# make a UUID from a string of hex digits (braces and hyphens ignored)
x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
# get the raw 16 bytes of the UUID
- self.assertEquals(
+ self.assertEqual(
'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f',
x.bytes
)
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_make_uuid_from_byte_string(self):
- self.assertEquals(
+ self.assertEqual(
uuid.UUID(bytes='\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f'),
uuid.UUID('00010203-0405-0607-0809-0a0b0c0d0e0f')
)
diff --git a/awx/lib/site-packages/django_extensions/tests/uuid_field.py b/awx/lib/site-packages/django_extensions/tests/uuid_field.py
index 43416e5734..823ff2e2ba 100644
--- a/awx/lib/site-packages/django_extensions/tests/uuid_field.py
+++ b/awx/lib/site-packages/django_extensions/tests/uuid_field.py
@@ -1,3 +1,4 @@
+import six
from django.conf import settings
from django.core.management import call_command
from django.db.models import loading
@@ -36,20 +37,22 @@ class UUIDFieldTest(unittest.TestCase):
settings.INSTALLED_APPS = self.old_installed_apps
def testUUIDFieldCreate(self):
- j = TestModel_field.objects.create(a=6, uuid_field=u'550e8400-e29b-41d4-a716-446655440000')
- self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440000')
+ j = TestModel_field.objects.create(a=6, uuid_field=six.u('550e8400-e29b-41d4-a716-446655440000'))
+ self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440000'))
def testUUIDField_pkCreate(self):
- j = TestModel_pk.objects.create(uuid_field=u'550e8400-e29b-41d4-a716-446655440000')
- self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440000')
- self.assertEquals(j.pk, u'550e8400-e29b-41d4-a716-446655440000')
+ j = TestModel_pk.objects.create(uuid_field=six.u('550e8400-e29b-41d4-a716-446655440000'))
+ self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440000'))
+ self.assertEqual(j.pk, six.u('550e8400-e29b-41d4-a716-446655440000'))
def testUUIDField_pkAgregateCreate(self):
- j = TestAgregateModel.objects.create(a=6)
- self.assertEquals(j.a, 6)
+ j = TestAgregateModel.objects.create(a=6, uuid_field=six.u('550e8400-e29b-41d4-a716-446655440001'))
+ self.assertEqual(j.a, 6)
+ self.assertIsInstance(j.pk, six.string_types)
+ self.assertEqual(len(j.pk), 36)
def testUUIDFieldManyToManyCreate(self):
- j = TestManyToManyModel.objects.create(uuid_field=u'550e8400-e29b-41d4-a716-446655440010')
- self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440010')
- self.assertEquals(j.pk, u'550e8400-e29b-41d4-a716-446655440010')
+ j = TestManyToManyModel.objects.create(uuid_field=six.u('550e8400-e29b-41d4-a716-446655440010'))
+ self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440010'))
+ self.assertEqual(j.pk, six.u('550e8400-e29b-41d4-a716-446655440010'))
diff --git a/awx/lib/site-packages/django_extensions/utils/dia2django.py b/awx/lib/site-packages/django_extensions/utils/dia2django.py
index 792529b10a..91a48631df 100644
--- a/awx/lib/site-packages/django_extensions/utils/dia2django.py
+++ b/awx/lib/site-packages/django_extensions/utils/dia2django.py
@@ -17,6 +17,7 @@ import sys
import gzip
from xml.dom.minidom import * # NOQA
import re
+import six
#Type dictionary translation types SQL -> Django
tsd = {
@@ -75,7 +76,7 @@ def dia2django(archivo):
datos = ppal.getElementsByTagName("dia:diagram")[0].getElementsByTagName("dia:layer")[0].getElementsByTagName("dia:object")
clases = {}
herit = []
- imports = u""
+ imports = six.u("")
for i in datos:
#Look for the classes
if i.getAttribute("type") == "UML - Class":
@@ -165,7 +166,7 @@ def dia2django(archivo):
a = i.getElementsByTagName("dia:string")
for j in a:
if len(j.childNodes[0].data[1:-1]):
- imports += u"from %s.models import *" % j.childNodes[0].data[1:-1]
+ imports += six.u("from %s.models import *" % j.childNodes[0].data[1:-1])
addparentstofks(herit, clases)
#Ordering the appearance of classes
diff --git a/awx/lib/site-packages/django_extensions/utils/uuid.py b/awx/lib/site-packages/django_extensions/utils/uuid.py
deleted file mode 100644
index 2684f22019..0000000000
--- a/awx/lib/site-packages/django_extensions/utils/uuid.py
+++ /dev/null
@@ -1,566 +0,0 @@
-# flake8:noqa
-r"""UUID objects (universally unique identifiers) according to RFC 4122.
-
-This module provides immutable UUID objects (class UUID) and the functions
-uuid1(), uuid3(), uuid4(), uuid5() for generating version 1, 3, 4, and 5
-UUIDs as specified in RFC 4122.
-
-If all you want is a unique ID, you should probably call uuid1() or uuid4().
-Note that uuid1() may compromise privacy since it creates a UUID containing
-the computer's network address. uuid4() creates a random UUID.
-
-Typical usage:
-
- >>> import uuid
-
- # make a UUID based on the host ID and current time
- >>> uuid.uuid1()
- UUID('a8098c1a-f86e-11da-bd1a-00112444be1e')
-
- # make a UUID using an MD5 hash of a namespace UUID and a name
- >>> uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org')
- UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e')
-
- # make a random UUID
- >>> uuid.uuid4()
- UUID('16fd2706-8baf-433b-82eb-8c7fada847da')
-
- # make a UUID using a SHA-1 hash of a namespace UUID and a name
- >>> uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org')
- UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d')
-
- # make a UUID from a string of hex digits (braces and hyphens ignored)
- >>> x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
-
- # convert a UUID to a string of hex digits in standard form
- >>> str(x)
- '00010203-0405-0607-0809-0a0b0c0d0e0f'
-
- # get the raw 16 bytes of the UUID
- >>> x.bytes
- '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
-
- # make a UUID from a 16-byte string
- >>> uuid.UUID(bytes=x.bytes)
- UUID('00010203-0405-0607-0809-0a0b0c0d0e0f')
-"""
-
-__author__ = 'Ka-Ping Yee <ping@zesty.ca>'
-
-RESERVED_NCS, RFC_4122, RESERVED_MICROSOFT, RESERVED_FUTURE = [
- 'reserved for NCS compatibility', 'specified in RFC 4122',
- 'reserved for Microsoft compatibility', 'reserved for future definition'
-]
-
-
-class UUID(object):
- """Instances of the UUID class represent UUIDs as specified in RFC 4122.
- UUID objects are immutable, hashable, and usable as dictionary keys.
- Converting a UUID to a string with str() yields something in the form
- '12345678-1234-1234-1234-123456789abc'. The UUID constructor accepts
- five possible forms: a similar string of hexadecimal digits, or a tuple
- of six integer fields (with 32-bit, 16-bit, 16-bit, 8-bit, 8-bit, and
- 48-bit values respectively) as an argument named 'fields', or a string
- of 16 bytes (with all the integer fields in big-endian order) as an
- argument named 'bytes', or a string of 16 bytes (with the first three
- fields in little-endian order) as an argument named 'bytes_le', or a
- single 128-bit integer as an argument named 'int'.
-
- UUIDs have these read-only attributes:
-
- bytes the UUID as a 16-byte string (containing the six
- integer fields in big-endian byte order)
-
- bytes_le the UUID as a 16-byte string (with time_low, time_mid,
- and time_hi_version in little-endian byte order)
-
- fields a tuple of the six integer fields of the UUID,
- which are also available as six individual attributes
- and two derived attributes:
-
- time_low the first 32 bits of the UUID
- time_mid the next 16 bits of the UUID
- time_hi_version the next 16 bits of the UUID
- clock_seq_hi_variant the next 8 bits of the UUID
- clock_seq_low the next 8 bits of the UUID
- node the last 48 bits of the UUID
-
- time the 60-bit timestamp
- clock_seq the 14-bit sequence number
-
- hex the UUID as a 32-character hexadecimal string
-
- int the UUID as a 128-bit integer
-
- urn the UUID as a URN as specified in RFC 4122
-
- variant the UUID variant (one of the constants RESERVED_NCS,
- RFC_4122, RESERVED_MICROSOFT, or RESERVED_FUTURE)
-
- version the UUID version number (1 through 5, meaningful only
- when the variant is RFC_4122)
- """
-
- def __init__(self, hex=None, bytes=None, bytes_le=None, fields=None, int=None, version=None):
- r"""Create a UUID from either a string of 32 hexadecimal digits,
- a string of 16 bytes as the 'bytes' argument, a string of 16 bytes
- in little-endian order as the 'bytes_le' argument, a tuple of six
- integers (32-bit time_low, 16-bit time_mid, 16-bit time_hi_version,
- 8-bit clock_seq_hi_variant, 8-bit clock_seq_low, 48-bit node) as
- the 'fields' argument, or a single 128-bit integer as the 'int'
- argument. When a string of hex digits is given, curly braces,
- hyphens, and a URN prefix are all optional. For example, these
- expressions all yield the same UUID:
-
- UUID('{12345678-1234-5678-1234-567812345678}')
- UUID('12345678123456781234567812345678')
- UUID('urn:uuid:12345678-1234-5678-1234-567812345678')
- UUID(bytes='\x12\x34\x56\x78'*4)
- UUID(bytes_le='\x78\x56\x34\x12\x34\x12\x78\x56' +
- '\x12\x34\x56\x78\x12\x34\x56\x78')
- UUID(fields=(0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678))
- UUID(int=0x12345678123456781234567812345678)
-
- Exactly one of 'hex', 'bytes', 'bytes_le', 'fields', or 'int' must
- be given. The 'version' argument is optional; if given, the resulting
- UUID will have its variant and version set according to RFC 4122,
- overriding the given 'hex', 'bytes', 'bytes_le', 'fields', or 'int'.
- """
-
- if [hex, bytes, bytes_le, fields, int].count(None) != 4:
- raise TypeError('need one of hex, bytes, bytes_le, fields, or int')
- if hex is not None:
- hex = hex.replace('urn:', '').replace('uuid:', '')
- hex = hex.strip('{}').replace('-', '')
- if len(hex) != 32:
- raise ValueError('badly formed hexadecimal UUID string')
- int = long(hex, 16)
- if bytes_le is not None:
- if len(bytes_le) != 16:
- raise ValueError('bytes_le is not a 16-char string')
- bytes = (bytes_le[3] + bytes_le[2] + bytes_le[1] + bytes_le[0] +
- bytes_le[5] + bytes_le[4] + bytes_le[7] + bytes_le[6] +
- bytes_le[8:])
- if bytes is not None:
- if len(bytes) != 16:
- raise ValueError('bytes is not a 16-char string')
- int = long(('%02x' * 16) % tuple(map(ord, bytes)), 16)
- if fields is not None:
- if len(fields) != 6:
- raise ValueError('fields is not a 6-tuple')
- (time_low, time_mid, time_hi_version,
- clock_seq_hi_variant, clock_seq_low, node) = fields
- if not 0 <= time_low < 1 << 32L:
- raise ValueError('field 1 out of range (need a 32-bit value)')
- if not 0 <= time_mid < 1 << 16L:
- raise ValueError('field 2 out of range (need a 16-bit value)')
- if not 0 <= time_hi_version < 1 << 16L:
- raise ValueError('field 3 out of range (need a 16-bit value)')
- if not 0 <= clock_seq_hi_variant < 1 << 8L:
- raise ValueError('field 4 out of range (need an 8-bit value)')
- if not 0 <= clock_seq_low < 1 << 8L:
- raise ValueError('field 5 out of range (need an 8-bit value)')
- if not 0 <= node < 1 << 48L:
- raise ValueError('field 6 out of range (need a 48-bit value)')
- clock_seq = (clock_seq_hi_variant << 8L) | clock_seq_low
- int = ((time_low << 96L) | (time_mid << 80L) |
- (time_hi_version << 64L) | (clock_seq << 48L) | node)
- if int is not None:
- if not 0 <= int < 1 << 128L:
- raise ValueError('int is out of range (need a 128-bit value)')
- if version is not None:
- if not 1 <= version <= 5:
- raise ValueError('illegal version number')
- # Set the variant to RFC 4122.
- int &= ~(0xc000 << 48L)
- int |= 0x8000 << 48L
- # Set the version number.
- int &= ~(0xf000 << 64L)
- int |= version << 76L
- self.__dict__['int'] = int
-
- def __cmp__(self, other):
- if isinstance(other, UUID):
- return cmp(self.int, other.int)
- return NotImplemented
-
- def __hash__(self):
- return hash(self.int)
-
- def __int__(self):
- return self.int
-
- def __repr__(self):
- return 'UUID(%r)' % str(self)
-
- def __setattr__(self, name, value):
- raise TypeError('UUID objects are immutable')
-
- def __str__(self):
- hex = '%032x' % self.int
- return '%s-%s-%s-%s-%s' % (
- hex[:8], hex[8:12], hex[12:16], hex[16:20], hex[20:])
-
- def get_bytes(self):
- bytes = ''
- for shift in range(0, 128, 8):
- bytes = chr((self.int >> shift) & 0xff) + bytes
- return bytes
-
- bytes = property(get_bytes)
-
- def get_bytes_le(self):
- bytes = self.bytes
- return (bytes[3] + bytes[2] + bytes[1] + bytes[0] +
- bytes[5] + bytes[4] + bytes[7] + bytes[6] + bytes[8:])
-
- bytes_le = property(get_bytes_le)
-
- def get_fields(self):
- return (self.time_low, self.time_mid, self.time_hi_version,
- self.clock_seq_hi_variant, self.clock_seq_low, self.node)
-
- fields = property(get_fields)
-
- def get_time_low(self):
- return self.int >> 96L
-
- time_low = property(get_time_low)
-
- def get_time_mid(self):
- return (self.int >> 80L) & 0xffff
-
- time_mid = property(get_time_mid)
-
- def get_time_hi_version(self):
- return (self.int >> 64L) & 0xffff
-
- time_hi_version = property(get_time_hi_version)
-
- def get_clock_seq_hi_variant(self):
- return (self.int >> 56L) & 0xff
-
- clock_seq_hi_variant = property(get_clock_seq_hi_variant)
-
- def get_clock_seq_low(self):
- return (self.int >> 48L) & 0xff
-
- clock_seq_low = property(get_clock_seq_low)
-
- def get_time(self):
- return (((self.time_hi_version & 0x0fffL) << 48L) |
- (self.time_mid << 32L) | self.time_low)
-
- time = property(get_time)
-
- def get_clock_seq(self):
- return (((self.clock_seq_hi_variant & 0x3fL) << 8L) |
- self.clock_seq_low)
-
- clock_seq = property(get_clock_seq)
-
- def get_node(self):
- return self.int & 0xffffffffffff
-
- node = property(get_node)
-
- def get_hex(self):
- return '%032x' % self.int
-
- hex = property(get_hex)
-
- def get_urn(self):
- return 'urn:uuid:' + str(self)
-
- urn = property(get_urn)
-
- def get_variant(self):
- if not self.int & (0x8000 << 48L):
- return RESERVED_NCS
- elif not self.int & (0x4000 << 48L):
- return RFC_4122
- elif not self.int & (0x2000 << 48L):
- return RESERVED_MICROSOFT
- else:
- return RESERVED_FUTURE
-
- variant = property(get_variant)
-
- def get_version(self):
- # The version bits are only meaningful for RFC 4122 UUIDs.
- if self.variant == RFC_4122:
- return int((self.int >> 76L) & 0xf)
-
- version = property(get_version)
-
-
-def _find_mac(command, args, hw_identifiers, get_index):
- import os
- for dir in ['', '/sbin/', '/usr/sbin']:
- executable = os.path.join(dir, command)
- if not os.path.exists(executable):
- continue
-
- try:
- # LC_ALL to get English output, 2>/dev/null to
- # prevent output on stderr
- cmd = 'LC_ALL=C %s %s 2>/dev/null' % (executable, args)
- pipe = os.popen(cmd)
- except IOError:
- continue
-
- for line in pipe:
- words = line.lower().split()
- for i in range(len(words)):
- if words[i] in hw_identifiers:
- return int(words[get_index(i)].replace(':', ''), 16)
- return None
-
-
-def _ifconfig_getnode():
- """Get the hardware address on Unix by running ifconfig."""
-
- # This works on Linux ('' or '-a'), Tru64 ('-av'), but not all Unixes.
- for args in ('', '-a', '-av'):
- mac = _find_mac('ifconfig', args, ['hwaddr', 'ether'], lambda i: i + 1)
- if mac:
- return mac
-
- import socket
- ip_addr = socket.gethostbyname(socket.gethostname())
-
- # Try getting the MAC addr from arp based on our IP address (Solaris).
- mac = _find_mac('arp', '-an', [ip_addr], lambda i: -1)
- if mac:
- return mac
-
- # This might work on HP-UX.
- mac = _find_mac('lanscan', '-ai', ['lan0'], lambda i: 0)
- if mac:
- return mac
-
- return None
-
-
-def _ipconfig_getnode():
- """Get the hardware address on Windows by running ipconfig.exe."""
- import os
- import re
- dirs = ['', r'c:\windows\system32', r'c:\winnt\system32']
- try:
- import ctypes
- buffer = ctypes.create_string_buffer(300)
- ctypes.windll.kernel32.GetSystemDirectoryA(buffer, 300)
- dirs.insert(0, buffer.value.decode('mbcs'))
- except:
- pass
- for dir in dirs:
- try:
- pipe = os.popen(os.path.join(dir, 'ipconfig') + ' /all')
- except IOError:
- continue
- for line in pipe:
- value = line.split(':')[-1].strip().lower()
- if re.match('([0-9a-f][0-9a-f]-){5}[0-9a-f][0-9a-f]', value):
- return int(value.replace('-', ''), 16)
-
-
-def _netbios_getnode():
- """Get the hardware address on Windows using NetBIOS calls.
- See http://support.microsoft.com/kb/118623 for details."""
- import win32wnet
- import netbios
- ncb = netbios.NCB()
- ncb.Command = netbios.NCBENUM
- ncb.Buffer = adapters = netbios.LANA_ENUM()
- adapters._pack()
- if win32wnet.Netbios(ncb) != 0:
- return
- adapters._unpack()
- for i in range(adapters.length):
- ncb.Reset()
- ncb.Command = netbios.NCBRESET
- ncb.Lana_num = ord(adapters.lana[i])
- if win32wnet.Netbios(ncb) != 0:
- continue
- ncb.Reset()
- ncb.Command = netbios.NCBASTAT
- ncb.Lana_num = ord(adapters.lana[i])
- ncb.Callname = '*'.ljust(16)
- ncb.Buffer = status = netbios.ADAPTER_STATUS()
- if win32wnet.Netbios(ncb) != 0:
- continue
- status._unpack()
- bytes = map(ord, status.adapter_address)
- return ((bytes[0] << 40L) + (bytes[1] << 32L) + (bytes[2] << 24L) +
- (bytes[3] << 16L) + (bytes[4] << 8L) + bytes[5])
-
-# Thanks to Thomas Heller for ctypes and for his help with its use here.
-
-# If ctypes is available, use it to find system routines for UUID generation.
-_uuid_generate_random = _uuid_generate_time = _UuidCreate = None
-try:
- import ctypes
- import ctypes.util
- _buffer = ctypes.create_string_buffer(16)
-
- # The uuid_generate_* routines are provided by libuuid on at least
- # Linux and FreeBSD, and provided by libc on Mac OS X.
- for libname in ['uuid', 'c']:
- try:
- lib = ctypes.CDLL(ctypes.util.find_library(libname))
- except:
- continue
- if hasattr(lib, 'uuid_generate_random'):
- _uuid_generate_random = lib.uuid_generate_random
- if hasattr(lib, 'uuid_generate_time'):
- _uuid_generate_time = lib.uuid_generate_time
-
- # On Windows prior to 2000, UuidCreate gives a UUID containing the
- # hardware address. On Windows 2000 and later, UuidCreate makes a
- # random UUID and UuidCreateSequential gives a UUID containing the
- # hardware address. These routines are provided by the RPC runtime.
- # NOTE: at least on Tim's WinXP Pro SP2 desktop box, while the last
- # 6 bytes returned by UuidCreateSequential are fixed, they don't appear
- # to bear any relationship to the MAC address of any network device
- # on the box.
- try:
- lib = ctypes.windll.rpcrt4
- except:
- lib = None
- _UuidCreate = getattr(lib, 'UuidCreateSequential',
- getattr(lib, 'UuidCreate', None))
-except:
- pass
-
-
-def _unixdll_getnode():
- """Get the hardware address on Unix using ctypes."""
- _uuid_generate_time(_buffer)
- return UUID(bytes=_buffer.raw).node
-
-
-def _windll_getnode():
- """Get the hardware address on Windows using ctypes."""
- if _UuidCreate(_buffer) == 0:
- return UUID(bytes=_buffer.raw).node
-
-
-def _random_getnode():
- """Get a random node ID, with eighth bit set as suggested by RFC 4122."""
- import random
- return random.randrange(0, 1 << 48L) | 0x010000000000L
-
-_node = None
-
-
-def getnode():
- """Get the hardware address as a 48-bit positive integer.
-
- The first time this runs, it may launch a separate program, which could
- be quite slow. If all attempts to obtain the hardware address fail, we
- choose a random 48-bit number with its eighth bit set to 1 as recommended
- in RFC 4122.
- """
-
- global _node
- if _node is not None:
- return _node
-
- import sys
- if sys.platform == 'win32':
- getters = [_windll_getnode, _netbios_getnode, _ipconfig_getnode]
- else:
- getters = [_unixdll_getnode, _ifconfig_getnode]
-
- for getter in getters + [_random_getnode]:
- try:
- _node = getter()
- except:
- continue
- if _node is not None:
- return _node
-
-_last_timestamp = None
-
-
-def uuid1(node=None, clock_seq=None):
- """Generate a UUID from a host ID, sequence number, and the current time.
- If 'node' is not given, getnode() is used to obtain the hardware
- address. If 'clock_seq' is given, it is used as the sequence number;
- otherwise a random 14-bit sequence number is chosen."""
-
- # When the system provides a version-1 UUID generator, use it (but don't
- # use UuidCreate here because its UUIDs don't conform to RFC 4122).
- if _uuid_generate_time and node is clock_seq is None:
- _uuid_generate_time(_buffer)
- return UUID(bytes=_buffer.raw)
-
- global _last_timestamp
- import time
- nanoseconds = int(time.time() * 1e9)
- # 0x01b21dd213814000 is the number of 100-ns intervals between the
- # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
- timestamp = int(nanoseconds / 100) + 0x01b21dd213814000L
- if timestamp <= _last_timestamp:
- timestamp = _last_timestamp + 1
- _last_timestamp = timestamp
- if clock_seq is None:
- import random
- clock_seq = random.randrange(1 << 14L) # instead of stable storage
- time_low = timestamp & 0xffffffffL
- time_mid = (timestamp >> 32L) & 0xffffL
- time_hi_version = (timestamp >> 48L) & 0x0fffL
- clock_seq_low = clock_seq & 0xffL
- clock_seq_hi_variant = (clock_seq >> 8L) & 0x3fL
- if node is None:
- node = getnode()
- return UUID(fields=(time_low, time_mid, time_hi_version,
- clock_seq_hi_variant, clock_seq_low, node), version=1)
-
-
-def uuid3(namespace, name):
- """Generate a UUID from the MD5 hash of a namespace UUID and a name."""
- try:
- import hashlib
- md5 = hashlib.md5
- except ImportError:
- from md5 import md5 # NOQA
- hash = md5(namespace.bytes + name).digest()
- return UUID(bytes=hash[:16], version=3)
-
-
-def uuid4():
- """Generate a random UUID."""
-
- # When the system provides a version-4 UUID generator, use it.
- if _uuid_generate_random:
- _uuid_generate_random(_buffer)
- return UUID(bytes=_buffer.raw)
-
- # Otherwise, get randomness from urandom or the 'random' module.
- try:
- import os
- return UUID(bytes=os.urandom(16), version=4)
- except:
- import random
- bytes = [chr(random.randrange(256)) for i in range(16)]
- return UUID(bytes=bytes, version=4)
-
-
-def uuid5(namespace, name):
- """Generate a UUID from the SHA-1 hash of a namespace UUID and a name."""
- try:
- import hashlib
- sha = hashlib.sha1
- except ImportError:
- from sha import sha # NOQA
- hash = sha(namespace.bytes + name).digest()
- return UUID(bytes=hash[:16], version=5)
-
-# The following standard UUIDs are for use with uuid3() or uuid5().
-
-NAMESPACE_DNS = UUID('6ba7b810-9dad-11d1-80b4-00c04fd430c8')
-NAMESPACE_URL = UUID('6ba7b811-9dad-11d1-80b4-00c04fd430c8')
-NAMESPACE_OID = UUID('6ba7b812-9dad-11d1-80b4-00c04fd430c8')
-NAMESPACE_X500 = UUID('6ba7b814-9dad-11d1-80b4-00c04fd430c8')
diff --git a/awx/lib/site-packages/djcelery/__init__.py b/awx/lib/site-packages/djcelery/__init__.py
index c281cfffdf..fea512138a 100644
--- a/awx/lib/site-packages/djcelery/__init__.py
+++ b/awx/lib/site-packages/djcelery/__init__.py
@@ -5,7 +5,7 @@ from __future__ import absolute_import
import os
-VERSION = (3, 0, 17)
+VERSION = (3, 0, 21)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org'
diff --git a/awx/lib/site-packages/djcelery/admin.py b/awx/lib/site-packages/djcelery/admin.py
index 85d9d4aefe..774b92cef2 100644
--- a/awx/lib/site-packages/djcelery/admin.py
+++ b/awx/lib/site-packages/djcelery/admin.py
@@ -8,7 +8,6 @@ from django.contrib.admin import helpers
from django.contrib.admin.views import main as main_views
from django.shortcuts import render_to_response
from django.template import RequestContext
-from django.utils.encoding import force_unicode
from django.utils.html import escape
from django.utils.translation import ugettext_lazy as _
@@ -22,6 +21,11 @@ from .models import (TaskState, WorkerState,
PeriodicTask, IntervalSchedule, CrontabSchedule)
from .humanize import naturaldate
+try:
+ from django.utils.encoding import force_text
+except ImportError:
+ from django.utils.encoding import force_unicode as force_text
+
TASK_STATE_COLORS = {states.SUCCESS: 'green',
states.FAILURE: 'red',
@@ -175,7 +179,7 @@ class TaskMonitor(ModelMonitor):
context = {
'title': _('Rate limit selection'),
'queryset': queryset,
- 'object_name': force_unicode(opts.verbose_name),
+ 'object_name': force_text(opts.verbose_name),
'action_checkbox_name': helpers.ACTION_CHECKBOX_NAME,
'opts': opts,
'app_label': app_label,
diff --git a/awx/lib/site-packages/djcelery/picklefield.py b/awx/lib/site-packages/djcelery/picklefield.py
index 1f2845df1f..73933e5124 100644
--- a/awx/lib/site-packages/djcelery/picklefield.py
+++ b/awx/lib/site-packages/djcelery/picklefield.py
@@ -20,7 +20,11 @@ from zlib import compress, decompress
from celery.utils.serialization import pickle
from django.db import models
-from django.utils.encoding import force_unicode
+
+try:
+ from django.utils.encoding import force_text
+except ImportError:
+ from django.utils.encoding import force_unicode as force_text
DEFAULT_PROTOCOL = 2
@@ -77,7 +81,7 @@ class PickledObjectField(models.Field):
def get_db_prep_value(self, value, **kwargs):
if value is not None and not isinstance(value, PickledObject):
- return force_unicode(encode(value, self.compress, self.protocol))
+ return force_text(encode(value, self.compress, self.protocol))
return value
def value_to_string(self, obj):
diff --git a/awx/lib/site-packages/djcelery/schedulers.py b/awx/lib/site-packages/djcelery/schedulers.py
index 135fb5a7b5..3e58508b30 100644
--- a/awx/lib/site-packages/djcelery/schedulers.py
+++ b/awx/lib/site-packages/djcelery/schedulers.py
@@ -68,12 +68,12 @@ class ModelEntry(ScheduleEntry):
def _default_now(self):
return self.app.now()
- def next(self):
+ def __next__(self):
self.model.last_run_at = self.app.now()
self.model.total_run_count += 1
self.model.no_changes = True
return self.__class__(self.model)
- __next__ = next # for 2to3
+ next = __next__ # for 2to3
def save(self):
# Object may not be synchronized, so only
diff --git a/awx/lib/site-packages/djcelery/snapshot.py b/awx/lib/site-packages/djcelery/snapshot.py
index c0a112f4ab..cf2ea4dd55 100644
--- a/awx/lib/site-packages/djcelery/snapshot.py
+++ b/awx/lib/site-packages/djcelery/snapshot.py
@@ -10,7 +10,7 @@ from django.conf import settings
from celery import states
from celery.events.state import Task
from celery.events.snapshot import Polaroid
-from celery.utils.timeutils import maybe_iso8601, timezone
+from celery.utils.timeutils import maybe_iso8601
from .models import WorkerState, TaskState
from .utils import maybe_make_aware
@@ -31,7 +31,7 @@ NOT_SAVED_ATTRIBUTES = frozenset(['name', 'args', 'kwargs', 'eta'])
def aware_tstamp(secs):
"""Event timestamps uses the local timezone."""
- return timezone.to_local_fallback(datetime.fromtimestamp(secs))
+ return maybe_make_aware(datetime.fromtimestamp(secs))
class Camera(Polaroid):
diff --git a/awx/lib/site-packages/djcelery/utils.py b/awx/lib/site-packages/djcelery/utils.py
index 02702ef390..fdfc0318ad 100644
--- a/awx/lib/site-packages/djcelery/utils.py
+++ b/awx/lib/site-packages/djcelery/utils.py
@@ -54,7 +54,8 @@ try:
def make_aware(value):
if getattr(settings, 'USE_TZ', False):
# naive datetimes are assumed to be in UTC.
- value = timezone.make_aware(value, timezone.utc)
+ if timezone.is_naive(value):
+ value = timezone.make_aware(value, timezone.utc)
# then convert to the Django configured timezone.
default_tz = timezone.get_default_timezone()
value = timezone.localtime(value, default_tz)
diff --git a/awx/lib/site-packages/kombu/__init__.py b/awx/lib/site-packages/kombu/__init__.py
index 962eae85d0..56d511de34 100644
--- a/awx/lib/site-packages/kombu/__init__.py
+++ b/awx/lib/site-packages/kombu/__init__.py
@@ -1,7 +1,7 @@
"""Messaging Framework for Python"""
from __future__ import absolute_import
-VERSION = (2, 5, 10)
+VERSION = (2, 5, 14)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org'
diff --git a/awx/lib/site-packages/kombu/abstract.py b/awx/lib/site-packages/kombu/abstract.py
index f4218f02b4..a4f0448e0c 100644
--- a/awx/lib/site-packages/kombu/abstract.py
+++ b/awx/lib/site-packages/kombu/abstract.py
@@ -1,6 +1,6 @@
"""
-kombu.compression
-=================
+kombu.abstract
+==============
Object utilities.
diff --git a/awx/lib/site-packages/kombu/connection.py b/awx/lib/site-packages/kombu/connection.py
index 9a6146dc9b..a7a6b15e78 100644
--- a/awx/lib/site-packages/kombu/connection.py
+++ b/awx/lib/site-packages/kombu/connection.py
@@ -151,8 +151,9 @@ class Connection(object):
password=None, virtual_host=None, port=None, insist=False,
ssl=False, transport=None, connect_timeout=5,
transport_options=None, login_method=None, uri_prefix=None,
- heartbeat=0, failover_strategy='round-robin', **kwargs):
- alt = []
+ heartbeat=0, failover_strategy='round-robin',
+ alternates=None, **kwargs):
+ alt = [] if alternates is None else alternates
# have to spell the args out, just to get nice docstrings :(
params = self._initial_params = {
'hostname': hostname, 'userid': userid,
@@ -328,6 +329,29 @@ class Connection(object):
self._debug('closed')
self._closed = True
+ def collect(self, socket_timeout=None):
+ # amqp requires communication to close, we don't need that just
+ # to clear out references, Transport._collect can also be implemented
+ # by other transports that want fast after fork
+ try:
+ gc_transport = self._transport._collect
+ except AttributeError:
+ _timeo = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(socket_timeout)
+ try:
+ self._close()
+ except socket.timeout:
+ pass
+ finally:
+ socket.setdefaulttimeout(_timeo)
+ else:
+ gc_transport(self._connection)
+ if self._transport:
+ self._transport.client = None
+ self._transport = None
+ self.declared_entities.clear()
+ self._connection = None
+
def release(self):
"""Close the connection (if open)."""
self._close()
@@ -522,12 +546,9 @@ class Connection(object):
transport_cls = RESOLVE_ALIASES.get(transport_cls, transport_cls)
D = self.transport.default_connection_params
- if self.alt:
- hostname = ";".join(self.alt)
- else:
- hostname = self.hostname or D.get('hostname')
- if self.uri_prefix:
- hostname = '%s+%s' % (self.uri_prefix, hostname)
+ hostname = self.hostname or D.get('hostname')
+ if self.uri_prefix:
+ hostname = '%s+%s' % (self.uri_prefix, hostname)
info = (('hostname', hostname),
('userid', self.userid or D.get('userid')),
@@ -542,6 +563,10 @@ class Connection(object):
('login_method', self.login_method or D.get('login_method')),
('uri_prefix', self.uri_prefix),
('heartbeat', self.heartbeat))
+
+ if self.alt:
+ info += (('alternates', self.alt),)
+
return info
def info(self):
@@ -910,6 +935,9 @@ class Resource(object):
else:
self.close_resource(resource)
+ def collect_resource(self, resource):
+ pass
+
def force_close_all(self):
"""Closes and removes all resources in the pool (also those in use).
@@ -919,32 +947,27 @@ class Resource(object):
"""
dirty = self._dirty
resource = self._resource
- while 1:
+ while 1: # - acquired
try:
dres = dirty.pop()
except KeyError:
break
try:
- self.close_resource(dres)
+ self.collect_resource(dres)
except AttributeError: # Issue #78
pass
-
- mutex = getattr(resource, 'mutex', None)
- if mutex:
- mutex.acquire()
- try:
- while 1:
- try:
- res = resource.queue.pop()
- except IndexError:
- break
- try:
- self.close_resource(res)
- except AttributeError:
- pass # Issue #78
- finally:
- if mutex: # pragma: no cover
- mutex.release()
+ while 1: # - available
+ # deque supports '.clear', but lists do not, so for that
+ # reason we use pop here, so that the underlying object can
+ # be any object supporting '.pop' and '.append'.
+ try:
+ res = resource.queue.pop()
+ except IndexError:
+ break
+ try:
+ self.collect_resource(res)
+ except AttributeError:
+ pass # Issue #78
if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover
_orig_acquire = acquire
@@ -993,6 +1016,9 @@ class ConnectionPool(Resource):
def close_resource(self, resource):
resource._close()
+ def collect_resource(self, resource, socket_timeout=0.1):
+ return resource.collect(socket_timeout)
+
@contextmanager
def acquire_channel(self, block=False):
with self.acquire(block=block) as connection:
diff --git a/awx/lib/site-packages/kombu/entity.py b/awx/lib/site-packages/kombu/entity.py
index 32b199c377..eaf242e69d 100644
--- a/awx/lib/site-packages/kombu/entity.py
+++ b/awx/lib/site-packages/kombu/entity.py
@@ -123,6 +123,7 @@ class Exchange(MaybeChannelBound):
type = 'direct'
durable = True
auto_delete = False
+ passive = False
delivery_mode = PERSISTENT_DELIVERY_MODE
attrs = (
@@ -130,6 +131,7 @@ class Exchange(MaybeChannelBound):
('type', None),
('arguments', None),
('durable', bool),
+ ('passive', bool),
('auto_delete', bool),
('delivery_mode', lambda m: DELIVERY_MODES.get(m) or m),
)
@@ -143,7 +145,7 @@ class Exchange(MaybeChannelBound):
def __hash__(self):
return hash('E|%s' % (self.name, ))
- def declare(self, nowait=False, passive=False):
+ def declare(self, nowait=False, passive=None):
"""Declare the exchange.
Creates the exchange on the broker.
@@ -152,6 +154,7 @@ class Exchange(MaybeChannelBound):
response will not be waited for. Default is :const:`False`.
"""
+ passive = self.passive if passive is None else passive
if self.name:
return self.channel.exchange_declare(
exchange=self.name, type=self.type, durable=self.durable,
@@ -489,7 +492,7 @@ class Queue(MaybeChannelBound):
self.exchange.declare(nowait)
self.queue_declare(nowait, passive=False)
- if self.exchange is not None:
+ if self.exchange and self.exchange.name:
self.queue_bind(nowait)
# - declare extra/multi-bindings.
@@ -541,8 +544,8 @@ class Queue(MaybeChannelBound):
Returns the message instance if a message was available,
or :const:`None` otherwise.
- :keyword no_ack: If set messages received does not have to
- be acknowledged.
+ :keyword no_ack: If enabled the broker will automatically
+ ack messages.
This method provides direct access to the messages in a
queue using a synchronous dialogue, designed for
@@ -575,8 +578,8 @@ class Queue(MaybeChannelBound):
can use the same consumer tags. If this field is empty
the server will generate a unique tag.
- :keyword no_ack: If set messages received does not have to
- be acknowledged.
+ :keyword no_ack: If enabled the broker will automatically ack
+ messages.
:keyword nowait: Do not wait for a reply.
diff --git a/awx/lib/site-packages/kombu/messaging.py b/awx/lib/site-packages/kombu/messaging.py
index 3067fc7b34..59be5bd085 100644
--- a/awx/lib/site-packages/kombu/messaging.py
+++ b/awx/lib/site-packages/kombu/messaging.py
@@ -276,8 +276,12 @@ class Consumer(object):
#: consume from.
queues = None
- #: Flag for message acknowledgment disabled/enabled.
- #: Enabled by default.
+ #: Flag for automatic message acknowledgment.
+ #: If enabled the messages are automatically acknowledged by the
+ #: broker. This can increase performance but means that you
+ #: have no control of when the message is removed.
+ #:
+ #: Disabled by default.
no_ack = None
#: By default all entities will be declared at instantiation, if you
@@ -399,6 +403,12 @@ class Consumer(object):
pass
def add_queue(self, queue):
+ """Add a queue to the list of queues to consume from.
+
+ This will not start consuming from the queue,
+ for that you will have to call :meth:`consume` after.
+
+ """
queue = queue(self.channel)
if self.auto_declare:
queue.declare()
@@ -406,9 +416,26 @@ class Consumer(object):
return queue
def add_queue_from_dict(self, queue, **options):
+ """This method is deprecated.
+
+ Instead please use::
+
+ consumer.add_queue(Queue.from_dict(d))
+
+ """
return self.add_queue(Queue.from_dict(queue, **options))
def consume(self, no_ack=None):
+ """Start consuming messages.
+
+ Can be called multiple times, but note that while it
+ will consume from new queues added since the last call,
+ it will not cancel consuming from removed queues (
+ use :meth:`cancel_by_queue`).
+
+ :param no_ack: See :attr:`no_ack`.
+
+ """
if self.queues:
no_ack = self.no_ack if no_ack is None else no_ack
@@ -441,10 +468,12 @@ class Consumer(object):
self.channel.basic_cancel(tag)
def consuming_from(self, queue):
+ """Returns :const:`True` if the consumer is currently
+ consuming from queue'."""
name = queue
if isinstance(queue, Queue):
name = queue.name
- return any(q.name == name for q in self.queues)
+ return name in self._active_tags
def purge(self):
"""Purge messages from all queues.
diff --git a/awx/lib/site-packages/kombu/mixins.py b/awx/lib/site-packages/kombu/mixins.py
index 82b9733e92..36f4817911 100644
--- a/awx/lib/site-packages/kombu/mixins.py
+++ b/awx/lib/site-packages/kombu/mixins.py
@@ -13,6 +13,7 @@ import socket
from contextlib import contextmanager
from functools import partial
from itertools import count
+from time import sleep
from .common import ignore_errors
from .messaging import Consumer
@@ -158,13 +159,18 @@ class ConsumerMixin(object):
def extra_context(self, connection, channel):
yield
- def run(self):
+ def run(self, _tokens=1):
+ restart_limit = self.restart_limit
+ errors = (self.connection.connection_errors +
+ self.connection.channel_errors)
while not self.should_stop:
try:
- if self.restart_limit.can_consume(1):
+ if restart_limit.can_consume(_tokens):
for _ in self.consume(limit=None):
pass
- except self.connection.connection_errors:
+ else:
+ sleep(restart_limit.expected_time(_tokens))
+ except errors:
warn('Connection to broker lost. '
'Trying to re-establish the connection...')
diff --git a/awx/lib/site-packages/kombu/pidbox.py b/awx/lib/site-packages/kombu/pidbox.py
index cc29ce1a99..0b42c809bb 100644
--- a/awx/lib/site-packages/kombu/pidbox.py
+++ b/awx/lib/site-packages/kombu/pidbox.py
@@ -20,6 +20,7 @@ from time import time
from . import Exchange, Queue, Consumer, Producer
from .clocks import LamportClock
from .common import maybe_declare, oid_from
+from .exceptions import InconsistencyError
from .utils import cached_property, kwdict, uuid
REPLY_QUEUE_EXPIRES = 10
@@ -215,9 +216,15 @@ class Mailbox(object):
delivery_mode='transient',
durable=False)
producer = Producer(chan, auto_declare=False)
- producer.publish(reply, exchange=exchange, routing_key=routing_key,
- declare=[exchange], headers={
- 'ticket': ticket, 'clock': self.clock.forward()})
+ try:
+ producer.publish(
+ reply, exchange=exchange, routing_key=routing_key,
+ declare=[exchange], headers={
+ 'ticket': ticket, 'clock': self.clock.forward(),
+ },
+ )
+ except InconsistencyError:
+ pass # queue probably deleted and no one is expecting a reply.
def _publish(self, type, arguments, destination=None,
reply_ticket=None, channel=None, timeout=None):
diff --git a/awx/lib/site-packages/kombu/tests/__init__.py b/awx/lib/site-packages/kombu/tests/__init__.py
index 6a13a75074..fb9f21a210 100644
--- a/awx/lib/site-packages/kombu/tests/__init__.py
+++ b/awx/lib/site-packages/kombu/tests/__init__.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import anyjson
+import atexit
import os
import sys
@@ -14,6 +15,24 @@ except ImportError:
anyjson.force_implementation('simplejson')
+def teardown():
+ # Workaround for multiprocessing bug where logging
+ # is attempted after global already collected at shutdown.
+ cancelled = set()
+ try:
+ import multiprocessing.util
+ cancelled.add(multiprocessing.util._exit_function)
+ except (AttributeError, ImportError):
+ pass
+
+ try:
+ atexit._exithandlers[:] = [
+ e for e in atexit._exithandlers if e[0] not in cancelled
+ ]
+ except AttributeError: # pragma: no cover
+ pass # Py3 missing _exithandlers
+
+
def find_distribution_modules(name=__name__, file=__file__):
current_dist_depth = len(name.split('.')) - 1
current_dist = os.path.join(os.path.dirname(file),
diff --git a/awx/lib/site-packages/kombu/tests/test_entities.py b/awx/lib/site-packages/kombu/tests/test_entities.py
index 8e89f84ea8..f2d54756b9 100644
--- a/awx/lib/site-packages/kombu/tests/test_entities.py
+++ b/awx/lib/site-packages/kombu/tests/test_entities.py
@@ -130,6 +130,10 @@ class test_Exchange(TestCase):
exc = Exchange('foo', 'direct', delivery_mode='transient')
self.assertEqual(exc.delivery_mode, Exchange.TRANSIENT_DELIVERY_MODE)
+ def test_set_passive_mode(self):
+ exc = Exchange('foo', 'direct', passive=True)
+ self.assertTrue(exc.passive)
+
def test_set_persistent_delivery_mode(self):
exc = Exchange('foo', 'direct', delivery_mode='persistent')
self.assertEqual(exc.delivery_mode, Exchange.PERSISTENT_DELIVERY_MODE)
diff --git a/awx/lib/site-packages/kombu/tests/test_messaging.py b/awx/lib/site-packages/kombu/tests/test_messaging.py
index 0fb8a65751..477e4e9970 100644
--- a/awx/lib/site-packages/kombu/tests/test_messaging.py
+++ b/awx/lib/site-packages/kombu/tests/test_messaging.py
@@ -258,9 +258,13 @@ class test_Consumer(TestCase):
def test_consuming_from(self):
consumer = self.connection.Consumer()
- consumer.queues[:] = [Queue('a'), Queue('b')]
+ consumer.queues[:] = [Queue('a'), Queue('b'), Queue('d')]
+ consumer._active_tags = {'a': 1, 'b': 2}
+
self.assertFalse(consumer.consuming_from(Queue('c')))
self.assertFalse(consumer.consuming_from('c'))
+ self.assertFalse(consumer.consuming_from(Queue('d')))
+ self.assertFalse(consumer.consuming_from('d'))
self.assertTrue(consumer.consuming_from(Queue('a')))
self.assertTrue(consumer.consuming_from(Queue('b')))
self.assertTrue(consumer.consuming_from('b'))
diff --git a/awx/lib/site-packages/kombu/tests/test_serialization.py b/awx/lib/site-packages/kombu/tests/test_serialization.py
index de7ed34312..d87a37f348 100644
--- a/awx/lib/site-packages/kombu/tests/test_serialization.py
+++ b/awx/lib/site-packages/kombu/tests/test_serialization.py
@@ -5,6 +5,8 @@ from __future__ import with_statement
import sys
+from base64 import b64decode
+
from kombu.serialization import (registry, register, SerializerNotInstalled,
raw_encode, register_yaml, register_msgpack,
decode, bytes_t, pickle, pickle_protocol,
@@ -53,16 +55,13 @@ unicode: "Th\\xE9 quick brown fox jumps over th\\xE9 lazy dog"
msgpack_py_data = dict(py_data)
-# msgpack only supports tuples
-msgpack_py_data['list'] = tuple(msgpack_py_data['list'])
# Unicode chars are lost in transmit :(
msgpack_py_data['unicode'] = 'Th quick brown fox jumps over th lazy dog'
-msgpack_data = """\
-\x85\xa3int\n\xa5float\xcb@\t!\xfbS\xc8\xd4\xf1\xa4list\
-\x94\xa6george\xa5jerry\xa6elaine\xa5cosmo\xa6string\xda\
-\x00+The quick brown fox jumps over the lazy dog\xa7unicode\
-\xda\x00)Th quick brown fox jumps over th lazy dog\
-"""
+msgpack_data = b64decode("""\
+haNpbnQKpWZsb2F0y0AJIftTyNTxpGxpc3SUpmdlb3JnZaVqZXJyeaZlbGFpbmWlY29zbW+mc3Rya\
+W5n2gArVGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZ6d1bmljb2Rl2g\
+ApVGggcXVpY2sgYnJvd24gZm94IGp1bXBzIG92ZXIgdGggbGF6eSBkb2c=\
+""")
def say(m):
diff --git a/awx/lib/site-packages/kombu/tests/test_utils.py b/awx/lib/site-packages/kombu/tests/test_utils.py
index 7090546c64..c718c9e59c 100644
--- a/awx/lib/site-packages/kombu/tests/test_utils.py
+++ b/awx/lib/site-packages/kombu/tests/test_utils.py
@@ -215,7 +215,7 @@ class test_retry_over_time(TestCase):
self.myfun, self.Predicate,
max_retries=1, errback=self.errback, interval_max=14,
)
- self.assertEqual(self.index, 2)
+ self.assertEqual(self.index, 1)
# no errback
self.assertRaises(
self.Predicate, utils.retry_over_time,
@@ -230,7 +230,7 @@ class test_retry_over_time(TestCase):
self.myfun, self.Predicate,
max_retries=0, errback=self.errback, interval_max=14,
)
- self.assertEqual(self.index, 1)
+ self.assertEqual(self.index, 0)
class test_cached_property(TestCase):
diff --git a/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py b/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py
index 745c014254..6840e407f8 100644
--- a/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py
+++ b/awx/lib/site-packages/kombu/tests/transport/test_pyamqp.py
@@ -3,8 +3,10 @@ from __future__ import with_statement
import sys
+from functools import partial
from mock import patch
from nose import SkipTest
+from itertools import count
try:
import amqp # noqa
@@ -43,6 +45,7 @@ class test_Channel(TestCase):
pass
self.conn = Mock()
+ self.conn._get_free_channel_id.side_effect = partial(next, count(0))
self.conn.channels = {}
self.channel = Channel(self.conn, 0)
diff --git a/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py b/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py
index b2942b671d..1bac68e6a6 100644
--- a/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py
+++ b/awx/lib/site-packages/kombu/tests/utilities/test_encoding.py
@@ -8,7 +8,7 @@ from contextlib import contextmanager
from mock import patch
from nose import SkipTest
-from kombu.utils.encoding import safe_str
+from kombu.utils.encoding import bytes_t, safe_str, default_encoding
from kombu.tests.utils import TestCase
@@ -26,16 +26,16 @@ def clean_encoding():
class test_default_encoding(TestCase):
- @patch('sys.getfilesystemencoding')
- def test_default(self, getfilesystemencoding):
- getfilesystemencoding.return_value = 'ascii'
+ @patch('sys.getdefaultencoding')
+ def test_default(self, getdefaultencoding):
+ getdefaultencoding.return_value = 'ascii'
with clean_encoding() as encoding:
enc = encoding.default_encoding()
if sys.platform.startswith('java'):
self.assertEqual(enc, 'utf-8')
else:
self.assertEqual(enc, 'ascii')
- getfilesystemencoding.assert_called_with()
+ getdefaultencoding.assert_called_with()
class test_encoding_utils(TestCase):
@@ -60,16 +60,36 @@ class test_encoding_utils(TestCase):
class test_safe_str(TestCase):
- def test_when_str(self):
+ def setUp(self):
+ self._cencoding = patch('sys.getdefaultencoding')
+ self._encoding = self._cencoding.__enter__()
+ self._encoding.return_value = 'ascii'
+
+ def tearDown(self):
+ self._cencoding.__exit__()
+
+ def test_when_bytes(self):
self.assertEqual(safe_str('foo'), 'foo')
def test_when_unicode(self):
- self.assertIsInstance(safe_str(u'foo'), str)
+ self.assertIsInstance(safe_str(u'foo'), bytes_t)
+
+ def test_when_encoding_utf8(self):
+ with patch('sys.getdefaultencoding') as encoding:
+ encoding.return_value = 'utf-8'
+ self.assertEqual(default_encoding(), 'utf-8')
+ s = u'The quiæk fåx jømps øver the lazy dåg'
+ res = safe_str(s)
+ self.assertIsInstance(res, bytes_t)
+ self.assertGreater(len(res), len(s))
def test_when_containing_high_chars(self):
- s = u'The quiæk fåx jømps øver the lazy dåg'
- res = safe_str(s)
- self.assertIsInstance(res, str)
+ with patch('sys.getdefaultencoding') as encoding:
+ encoding.return_value = 'ascii'
+ s = u'The quiæk fåx jømps øver the lazy dåg'
+ res = safe_str(s)
+ self.assertIsInstance(res, bytes_t)
+ self.assertEqual(len(s), len(res))
def test_when_not_string(self):
o = object()
diff --git a/awx/lib/site-packages/kombu/transport/librabbitmq.py b/awx/lib/site-packages/kombu/transport/librabbitmq.py
index acfa33e701..0909709198 100644
--- a/awx/lib/site-packages/kombu/transport/librabbitmq.py
+++ b/awx/lib/site-packages/kombu/transport/librabbitmq.py
@@ -9,6 +9,7 @@ kombu.transport.librabbitmq
"""
from __future__ import absolute_import
+import os
import socket
try:
@@ -28,6 +29,10 @@ from . import base
DEFAULT_PORT = 5672
+NO_SSL_ERROR = """\
+ssl not supported by librabbitmq, please use pyamqp:// or stunnel\
+"""
+
class Message(base.Message):
@@ -98,22 +103,41 @@ class Transport(base.Transport):
for name, default_value in self.default_connection_params.items():
if not getattr(conninfo, name, None):
setattr(conninfo, name, default_value)
- conn = self.Connection(host=conninfo.host,
- userid=conninfo.userid,
- password=conninfo.password,
- virtual_host=conninfo.virtual_host,
- login_method=conninfo.login_method,
- insist=conninfo.insist,
- ssl=conninfo.ssl,
- connect_timeout=conninfo.connect_timeout)
+ if conninfo.ssl:
+ raise NotImplementedError(NO_SSL_ERROR)
+ opts = dict({
+ 'host': conninfo.host,
+ 'userid': conninfo.userid,
+ 'password': conninfo.password,
+ 'virtual_host': conninfo.virtual_host,
+ 'login_method': conninfo.login_method,
+ 'insist': conninfo.insist,
+ 'ssl': conninfo.ssl,
+ 'connect_timeout': conninfo.connect_timeout,
+ }, **conninfo.transport_options or {})
+ conn = self.Connection(**opts)
conn.client = self.client
self.client.drain_events = conn.drain_events
return conn
def close_connection(self, connection):
"""Close the AMQP broker connection."""
+ self.client.drain_events = None
connection.close()
+ def _collect(self, connection):
+ if connection is not None:
+ for channel in connection.channels.itervalues():
+ channel.connection = None
+ try:
+ os.close(connection.fileno())
+ except OSError:
+ pass
+ connection.channels.clear()
+ connection.callbacks.clear()
+ self.client.drain_events = None
+ self.client = None
+
def verify_connection(self, connection):
return connection.connected
diff --git a/awx/lib/site-packages/kombu/transport/mongodb.py b/awx/lib/site-packages/kombu/transport/mongodb.py
index 1a09d1f42b..ff0ff3e811 100644
--- a/awx/lib/site-packages/kombu/transport/mongodb.py
+++ b/awx/lib/site-packages/kombu/transport/mongodb.py
@@ -101,57 +101,42 @@ class Channel(virtual.Channel):
See mongodb uri documentation:
http://www.mongodb.org/display/DOCS/Connections
"""
- conninfo = self.connection.client
+ client = self.connection.client
+ hostname = client.hostname or DEFAULT_HOST
+ authdb = dbname = client.virtual_host
- dbname = None
- hostname = None
-
- if not conninfo.hostname:
- conninfo.hostname = DEFAULT_HOST
-
- for part in conninfo.hostname.split('/'):
- if not hostname:
- hostname = 'mongodb://' + part
- continue
-
- dbname = part
- if '?' in part:
- # In case someone is passing options
- # to the mongodb connection. Right now
- # it is not permitted by kombu
- dbname, options = part.split('?')
- hostname += '/?' + options
-
- hostname = "%s/%s" % (
- hostname, dbname in [None, "/"] and "admin" or dbname,
- )
- if not dbname or dbname == "/":
+ if dbname in ["/", None]:
dbname = "kombu_default"
+ authdb = "admin"
+ if not client.userid:
+ hostname = hostname.replace('/' + client.virtual_host, '/')
+ else:
+ hostname = hostname.replace('/' + client.virtual_host,
+ '/' + authdb)
+
+ mongo_uri = 'mongodb://' + hostname
# At this point we expect the hostname to be something like
# (considering replica set form too):
#
# mongodb://[username:password@]host1[:port1][,host2[:port2],
# ...[,hostN[:portN]]][/[?options]]
- mongoconn = Connection(host=hostname, ssl=conninfo.ssl)
+ mongoconn = Connection(host=mongo_uri, ssl=client.ssl)
+ database = getattr(mongoconn, dbname)
+
version = mongoconn.server_info()['version']
if tuple(map(int, version.split('.')[:2])) < (1, 3):
raise NotImplementedError(
'Kombu requires MongoDB version 1.3+, but connected to %s' % (
version, ))
- database = getattr(mongoconn, dbname)
-
- # This is done by the connection uri
- # if conninfo.userid:
- # database.authenticate(conninfo.userid, conninfo.password)
self.db = database
col = database.messages
col.ensure_index([('queue', 1), ('_id', 1)], background=True)
if 'messages.broadcast' not in database.collection_names():
- capsize = conninfo.transport_options.get(
- 'capped_queue_size') or 100000
+ capsize = (client.transport_options.get('capped_queue_size')
+ or 100000)
database.create_collection('messages.broadcast',
size=capsize, capped=True)
diff --git a/awx/lib/site-packages/kombu/transport/pyamqp.py b/awx/lib/site-packages/kombu/transport/pyamqp.py
index 8eb62dcccb..8b066e7b64 100644
--- a/awx/lib/site-packages/kombu/transport/pyamqp.py
+++ b/awx/lib/site-packages/kombu/transport/pyamqp.py
@@ -75,14 +75,17 @@ class Transport(base.Transport):
channel_errors = (StdChannelError, ) + amqp.Connection.channel_errors
nb_keep_draining = True
- driver_name = "py-amqp"
- driver_type = "amqp"
+ driver_name = 'py-amqp'
+ driver_type = 'amqp'
supports_heartbeats = True
supports_ev = True
def __init__(self, client, **kwargs):
self.client = client
- self.default_port = kwargs.get("default_port") or self.default_port
+ self.default_port = kwargs.get('default_port') or self.default_port
+
+ def driver_version(self):
+ return amqp.__version__
def create_channel(self, connection):
return connection.channel()
@@ -98,15 +101,18 @@ class Transport(base.Transport):
setattr(conninfo, name, default_value)
if conninfo.hostname == 'localhost':
conninfo.hostname = '127.0.0.1'
- conn = self.Connection(host=conninfo.host,
- userid=conninfo.userid,
- password=conninfo.password,
- login_method=conninfo.login_method,
- virtual_host=conninfo.virtual_host,
- insist=conninfo.insist,
- ssl=conninfo.ssl,
- connect_timeout=conninfo.connect_timeout,
- heartbeat=conninfo.heartbeat)
+ opts = dict({
+ 'host': conninfo.host,
+ 'userid': conninfo.userid,
+ 'password': conninfo.password,
+ 'login_method': conninfo.login_method,
+ 'virtual_host': conninfo.virtual_host,
+ 'insist': conninfo.insist,
+ 'ssl': conninfo.ssl,
+ 'connect_timeout': conninfo.connect_timeout,
+ 'heartbeat': conninfo.heartbeat,
+ }, **conninfo.transport_options or {})
+ conn = self.Connection(**opts)
conn.client = self.client
return conn
diff --git a/awx/lib/site-packages/kombu/transport/redis.py b/awx/lib/site-packages/kombu/transport/redis.py
index bf9e9f49ad..687b57c586 100644
--- a/awx/lib/site-packages/kombu/transport/redis.py
+++ b/awx/lib/site-packages/kombu/transport/redis.py
@@ -206,7 +206,7 @@ class MultiChannelPoller(object):
for fd in self._chan_to_sock.itervalues():
try:
self.poller.unregister(fd)
- except KeyError:
+ except (KeyError, ValueError):
pass
self._channels.clear()
self._fd_to_chan.clear()
@@ -707,11 +707,12 @@ class Channel(virtual.Channel):
return self._queue_cycle[0:active]
def _rotate_cycle(self, used):
- """
- Move most recently used queue to end of list
- """
- index = self._queue_cycle.index(used)
- self._queue_cycle.append(self._queue_cycle.pop(index))
+ """Move most recently used queue to end of list."""
+ cycle = self._queue_cycle
+ try:
+ cycle.append(cycle.pop(cycle.index(used)))
+ except ValueError:
+ pass
def _get_response_error(self):
from redis import exceptions
diff --git a/awx/lib/site-packages/kombu/utils/__init__.py b/awx/lib/site-packages/kombu/utils/__init__.py
index 532fb883b4..c52382641a 100644
--- a/awx/lib/site-packages/kombu/utils/__init__.py
+++ b/awx/lib/site-packages/kombu/utils/__init__.py
@@ -216,7 +216,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
try:
return fun(*args, **kwargs)
except catch, exc:
- if max_retries is not None and retries > max_retries:
+ if max_retries is not None and retries >= max_retries:
raise
if callback:
callback()
diff --git a/awx/lib/site-packages/kombu/utils/encoding.py b/awx/lib/site-packages/kombu/utils/encoding.py
index efeeb32957..a4a1698938 100644
--- a/awx/lib/site-packages/kombu/utils/encoding.py
+++ b/awx/lib/site-packages/kombu/utils/encoding.py
@@ -22,7 +22,7 @@ if sys.platform.startswith('java'): # pragma: no cover
else:
def default_encoding(): # noqa
- return sys.getfilesystemencoding()
+ return sys.getdefaultencoding()
if is_py3k: # pragma: no cover
diff --git a/awx/lib/site-packages/kombu/utils/eventio.py b/awx/lib/site-packages/kombu/utils/eventio.py
index e9acc923e8..80b1501f9e 100644
--- a/awx/lib/site-packages/kombu/utils/eventio.py
+++ b/awx/lib/site-packages/kombu/utils/eventio.py
@@ -10,7 +10,7 @@ from __future__ import absolute_import
import errno
import socket
-from select import select as _selectf
+from select import select as _selectf, error as _selecterr
try:
from select import epoll
@@ -53,6 +53,11 @@ READ = POLL_READ = 0x001
WRITE = POLL_WRITE = 0x004
ERR = POLL_ERR = 0x008 | 0x010
+try:
+ SELECT_BAD_FD = set((errno.EBADF, errno.WSAENOTSOCK))
+except AttributeError:
+ SELECT_BAD_FD = set((errno.EBADF,))
+
class Poller(object):
@@ -79,11 +84,9 @@ class _epoll(Poller):
def unregister(self, fd):
try:
self._epoll.unregister(fd)
- except socket.error:
- pass
- except ValueError:
+ except (socket.error, ValueError, KeyError):
pass
- except IOError, exc:
+ except (IOError, OSError), exc:
if get_errno(exc) != errno.ENOENT:
raise
@@ -191,13 +194,31 @@ class _select(Poller):
if events & READ:
self._rfd.add(fd)
+ def _remove_bad(self):
+ for fd in self._rfd | self._wfd | self._efd:
+ try:
+ _selectf([fd], [], [], 0)
+ except (_selecterr, socket.error), exc:
+ if get_errno(exc) in SELECT_BAD_FD:
+ self.unregister(fd)
+
def unregister(self, fd):
self._rfd.discard(fd)
self._wfd.discard(fd)
self._efd.discard(fd)
def _poll(self, timeout):
- read, write, error = _selectf(self._rfd, self._wfd, self._efd, timeout)
+ try:
+ read, write, error = _selectf(
+ self._rfd, self._wfd, self._efd, timeout,
+ )
+ except (_selecterr, socket.error), exc:
+ if get_errno(exc) == errno.EINTR:
+ return
+ elif get_errno(exc) in SELECT_BAD_FD:
+ return self._remove_bad()
+ raise
+
events = {}
for fd in read:
if not isinstance(fd, int):
diff --git a/awx/lib/site-packages/rest_framework/__init__.py b/awx/lib/site-packages/rest_framework/__init__.py
index 0a21018634..087808e0b2 100644
--- a/awx/lib/site-packages/rest_framework/__init__.py
+++ b/awx/lib/site-packages/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.3.5'
+__version__ = '2.3.7'
VERSION = __version__ # synonym
diff --git a/awx/lib/site-packages/rest_framework/authentication.py b/awx/lib/site-packages/rest_framework/authentication.py
index 9caca78894..cf001a24dd 100644
--- a/awx/lib/site-packages/rest_framework/authentication.py
+++ b/awx/lib/site-packages/rest_framework/authentication.py
@@ -3,14 +3,13 @@ Provides various authentication policies.
"""
from __future__ import unicode_literals
import base64
-from datetime import datetime
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider
+from rest_framework.compat import oauth2_provider, provider_now
from rest_framework.authtoken.models import Token
@@ -27,6 +26,12 @@ def get_authorization_header(request):
return auth
+class CSRFCheck(CsrfViewMiddleware):
+ def _reject(self, request, reason):
+ # Return the failure reason instead of an HttpResponse
+ return reason
+
+
class BaseAuthentication(object):
"""
All authentication classes should extend BaseAuthentication.
@@ -104,27 +109,27 @@ class SessionAuthentication(BaseAuthentication):
"""
# Get the underlying HttpRequest object
- http_request = request._request
- user = getattr(http_request, 'user', None)
+ request = request._request
+ user = getattr(request, 'user', None)
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
return None
- # Enforce CSRF validation for session based authentication.
- class CSRFCheck(CsrfViewMiddleware):
- def _reject(self, request, reason):
- # Return the failure reason instead of an HttpResponse
- return reason
+ self.enforce_csrf(request)
+
+ # CSRF passed with authenticated user
+ return (user, None)
- reason = CSRFCheck().process_view(http_request, None, (), {})
+ def enforce_csrf(self, request):
+ """
+ Enforce CSRF validation for session based authentication.
+ """
+ reason = CSRFCheck().process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
- # CSRF passed with authenticated user
- return (user, None)
-
class TokenAuthentication(BaseAuthentication):
"""
@@ -230,8 +235,9 @@ class OAuthAuthentication(BaseAuthentication):
try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider.store.InvalidConsumerError as err:
- raise exceptions.AuthenticationFailed(err)
+ except oauth_provider.store.InvalidConsumerError:
+ msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
+ raise exceptions.AuthenticationFailed(msg)
if consumer.status != oauth_provider.consts.ACCEPTED:
msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
@@ -319,9 +325,9 @@ class OAuth2Authentication(BaseAuthentication):
try:
token = oauth2_provider.models.AccessToken.objects.select_related('user')
- # TODO: Change to timezone aware datetime when oauth2_provider add
- # support to it.
- token = token.get(token=access_token, expires__gt=datetime.now())
+ # provider_now switches to timezone aware datetime when
+ # the oauth2_provider version supports to it.
+ token = token.get(token=access_token, expires__gt=provider_now())
except oauth2_provider.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
diff --git a/awx/lib/site-packages/rest_framework/authtoken/admin.py b/awx/lib/site-packages/rest_framework/authtoken/admin.py
new file mode 100644
index 0000000000..ec28eb1ca2
--- /dev/null
+++ b/awx/lib/site-packages/rest_framework/authtoken/admin.py
@@ -0,0 +1,11 @@
+from django.contrib import admin
+from rest_framework.authtoken.models import Token
+
+
+class TokenAdmin(admin.ModelAdmin):
+ list_display = ('key', 'user', 'created')
+ fields = ('user',)
+ ordering = ('-created',)
+
+
+admin.site.register(Token, TokenAdmin)
diff --git a/awx/lib/site-packages/rest_framework/authtoken/models.py b/awx/lib/site-packages/rest_framework/authtoken/models.py
index 52c45ad11f..7601f5b791 100644
--- a/awx/lib/site-packages/rest_framework/authtoken/models.py
+++ b/awx/lib/site-packages/rest_framework/authtoken/models.py
@@ -1,7 +1,7 @@
import uuid
import hmac
from hashlib import sha1
-from rest_framework.compat import User
+from rest_framework.compat import AUTH_USER_MODEL
from django.conf import settings
from django.db import models
@@ -11,7 +11,7 @@ class Token(models.Model):
The default authorization token model.
"""
key = models.CharField(max_length=40, primary_key=True)
- user = models.OneToOneField(User, related_name='auth_token')
+ user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
class Meta:
diff --git a/awx/lib/site-packages/rest_framework/compat.py b/awx/lib/site-packages/rest_framework/compat.py
index 76dc00526c..6f7447add0 100644
--- a/awx/lib/site-packages/rest_framework/compat.py
+++ b/awx/lib/site-packages/rest_framework/compat.py
@@ -2,11 +2,13 @@
The `compat` module provides support for backwards compatibility with older
versions of django/python, and compatibility wrappers around optional packages.
"""
+
# flake8: noqa
from __future__ import unicode_literals
import django
from django.core.exceptions import ImproperlyConfigured
+from django.conf import settings
# Try to import six from Django, fallback to included `six`.
try:
@@ -33,6 +35,12 @@ except ImportError:
from django.utils.encoding import force_unicode as force_text
+# HttpResponseBase only exists from 1.5 onwards
+try:
+ from django.http.response import HttpResponseBase
+except ImportError:
+ from django.http import HttpResponse as HttpResponseBase
+
# django-filter is optional
try:
import django_filters
@@ -76,16 +84,9 @@ def get_concrete_model(model_cls):
# Django 1.5 add support for custom auth user model
if django.VERSION >= (1, 5):
- from django.conf import settings
- if hasattr(settings, 'AUTH_USER_MODEL'):
- User = settings.AUTH_USER_MODEL
- else:
- from django.contrib.auth.models import User
+ AUTH_USER_MODEL = settings.AUTH_USER_MODEL
else:
- try:
- from django.contrib.auth.models import User
- except ImportError:
- raise ImportError("User model is not to be found.")
+ AUTH_USER_MODEL = 'auth.User'
if django.VERSION >= (1, 5):
@@ -435,6 +436,42 @@ except ImportError:
return force_text(url)
+# RequestFactory only provide `generic` from 1.5 onwards
+
+from django.test.client import RequestFactory as DjangoRequestFactory
+from django.test.client import FakePayload
+try:
+ # In 1.5 the test client uses force_bytes
+ from django.utils.encoding import force_bytes_or_smart_bytes
+except ImportError:
+ # In 1.3 and 1.4 the test client just uses smart_str
+ from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
+
+
+class RequestFactory(DjangoRequestFactory):
+ def generic(self, method, path,
+ data='', content_type='application/octet-stream', **extra):
+ parsed = urlparse.urlparse(path)
+ data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
+ r = {
+ 'PATH_INFO': self._get_path(parsed),
+ 'QUERY_STRING': force_text(parsed[4]),
+ 'REQUEST_METHOD': str(method),
+ }
+ if data:
+ r.update({
+ 'CONTENT_LENGTH': len(data),
+ 'CONTENT_TYPE': str(content_type),
+ 'wsgi.input': FakePayload(data),
+ })
+ elif django.VERSION <= (1, 4):
+ # For 1.3 we need an empty WSGI payload
+ r.update({
+ 'wsgi.input': FakePayload('')
+ })
+ r.update(extra)
+ return self.request(**r)
+
# Markdown is optional
try:
import markdown
@@ -489,12 +526,22 @@ try:
from provider.oauth2 import forms as oauth2_provider_forms
from provider import scope as oauth2_provider_scope
from provider import constants as oauth2_constants
+ from provider import __version__ as provider_version
+ if provider_version in ('0.2.3', '0.2.4'):
+ # 0.2.3 and 0.2.4 are supported version that do not support
+ # timezone aware datetimes
+ import datetime
+ provider_now = datetime.datetime.now
+ else:
+ # Any other supported version does use timezone aware datetimes
+ from django.utils.timezone import now as provider_now
except ImportError:
oauth2_provider = None
oauth2_provider_models = None
oauth2_provider_forms = None
oauth2_provider_scope = None
oauth2_constants = None
+ provider_now = None
# Handle lazy strings
from django.utils.functional import Promise
diff --git a/awx/lib/site-packages/rest_framework/exceptions.py b/awx/lib/site-packages/rest_framework/exceptions.py
index 0c96ecdd52..425a721499 100644
--- a/awx/lib/site-packages/rest_framework/exceptions.py
+++ b/awx/lib/site-packages/rest_framework/exceptions.py
@@ -86,10 +86,3 @@ class Throttled(APIException):
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
else:
self.detail = detail or self.default_detail
-
-
-class ConfigurationError(Exception):
- """
- Indicates an internal server error.
- """
- pass
diff --git a/awx/lib/site-packages/rest_framework/fields.py b/awx/lib/site-packages/rest_framework/fields.py
index 535aa2ac8e..add9d224d3 100644
--- a/awx/lib/site-packages/rest_framework/fields.py
+++ b/awx/lib/site-packages/rest_framework/fields.py
@@ -7,25 +7,24 @@ from __future__ import unicode_literals
import copy
import datetime
-from decimal import Decimal, DecimalException
import inspect
import re
import warnings
+from decimal import Decimal, DecimalException
+from django import forms
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
-from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601
-from rest_framework.compat import (timezone, parse_date, parse_datetime,
- parse_time)
-from rest_framework.compat import BytesIO
-from rest_framework.compat import six
-from rest_framework.compat import smart_text, force_text, is_non_str_iterable
+from rest_framework.compat import (
+ timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
+ force_text, is_non_str_iterable
+)
from rest_framework.settings import api_settings
@@ -101,6 +100,19 @@ def humanize_strptime(format_string):
return format_string
+def strip_multiple_choice_msg(help_text):
+ """
+ Remove the 'Hold down "control" ...' message that is Django enforces in
+ select multiple fields on ModelForms. (Required for 1.5 and earlier)
+
+ See https://code.djangoproject.com/ticket/9321
+ """
+ multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.')
+ multiple_choice_msg = force_text(multiple_choice_msg)
+
+ return help_text.replace(multiple_choice_msg, '')
+
+
class Field(object):
read_only = True
creation_counter = 0
@@ -123,7 +135,7 @@ class Field(object):
self.label = smart_text(label)
if help_text is not None:
- self.help_text = smart_text(help_text)
+ self.help_text = strip_multiple_choice_msg(smart_text(help_text))
def initialize(self, parent, field_name):
"""
@@ -256,6 +268,12 @@ class WritableField(Field):
widget = widget()
self.widget = widget
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ result.validators = self.validators[:]
+ return result
+
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@@ -331,9 +349,13 @@ class ModelField(WritableField):
raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
+ getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
+ getattr(self.model_field, 'max_length', None))
+ self.min_value = kwargs.pop('min_value',
+ getattr(self.model_field, 'min_value', None))
+ self.max_value = kwargs.pop('max_value',
+ getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
@@ -341,6 +363,10 @@ class ModelField(WritableField):
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
+ if self.min_value is not None:
+ self.validators.append(validators.MinValueValidator(self.min_value))
+ if self.max_value is not None:
+ self.validators.append(validators.MaxValueValidator(self.max_value))
def from_native(self, value):
rel = getattr(self.model_field, "rel", None)
@@ -428,13 +454,6 @@ class SlugField(CharField):
def __init__(self, *args, **kwargs):
super(SlugField, self).__init__(*args, **kwargs)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- #result.widget = copy.deepcopy(self.widget, memo)
- result.validators = self.validators[:]
- return result
-
class ChoiceField(WritableField):
type_name = 'ChoiceField'
@@ -493,7 +512,7 @@ class EmailField(CharField):
form_field_class = forms.EmailField
default_error_messages = {
- 'invalid': _('Enter a valid e-mail address.'),
+ 'invalid': _('Enter a valid email address.'),
}
default_validators = [validators.validate_email]
@@ -503,13 +522,6 @@ class EmailField(CharField):
return None
return ret.strip()
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- #result.widget = copy.deepcopy(self.widget, memo)
- result.validators = self.validators[:]
- return result
-
class RegexField(CharField):
type_name = 'RegexField'
@@ -534,12 +546,6 @@ class RegexField(CharField):
regex = property(_get_regex, _set_regex)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- result.validators = self.validators[:]
- return result
-
class DateField(WritableField):
type_name = 'DateField'
@@ -918,7 +924,7 @@ class ImageField(FileField):
if f is None:
return None
- from compat import Image
+ from rest_framework.compat import Image
assert Image is not None, 'PIL must be installed for ImageField support'
# We need to get a file object for PIL. We might have a path or we might
diff --git a/awx/lib/site-packages/rest_framework/filters.py b/awx/lib/site-packages/rest_framework/filters.py
index c058bc715e..4079e1bd55 100644
--- a/awx/lib/site-packages/rest_framework/filters.py
+++ b/awx/lib/site-packages/rest_framework/filters.py
@@ -109,8 +109,7 @@ class OrderingFilter(BaseFilterBackend):
def get_ordering(self, request):
"""
- Search terms are set by a ?search=... query parameter,
- and may be comma and/or whitespace delimited.
+ Ordering is set by a comma delimited ?ordering=... query parameter.
"""
params = request.QUERY_PARAMS.get(self.ordering_param)
if params:
@@ -134,7 +133,7 @@ class OrderingFilter(BaseFilterBackend):
ordering = self.remove_invalid_fields(queryset, ordering)
if not ordering:
- # Use 'ordering' attribtue by default
+ # Use 'ordering' attribute by default
ordering = self.get_default_ordering(view)
if ordering:
diff --git a/awx/lib/site-packages/rest_framework/generics.py b/awx/lib/site-packages/rest_framework/generics.py
index 9ccc789805..99e9782e21 100644
--- a/awx/lib/site-packages/rest_framework/generics.py
+++ b/awx/lib/site-packages/rest_framework/generics.py
@@ -212,7 +212,7 @@ class GenericAPIView(views.APIView):
You may want to override this if you need to provide different
serializations depending on the incoming request.
- (Eg. admins get full serialization, others get basic serilization)
+ (Eg. admins get full serialization, others get basic serialization)
"""
serializer_class = self.serializer_class
if serializer_class is not None:
@@ -285,7 +285,7 @@ class GenericAPIView(views.APIView):
)
filter_kwargs = {self.slug_field: slug}
else:
- raise exceptions.ConfigurationError(
+ raise ImproperlyConfigured(
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
diff --git a/awx/lib/site-packages/rest_framework/parsers.py b/awx/lib/site-packages/rest_framework/parsers.py
index 25be2e6abc..96bfac84a0 100644
--- a/awx/lib/site-packages/rest_framework/parsers.py
+++ b/awx/lib/site-packages/rest_framework/parsers.py
@@ -50,10 +50,7 @@ class JSONParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None):
"""
- Returns a 2-tuple of `(data, files)`.
-
- `data` will be an object which is the parsed content of the response.
- `files` will always be `None`.
+ Parses the incoming bytestream as JSON and returns the resulting data.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
@@ -74,10 +71,7 @@ class YAMLParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None):
"""
- Returns a 2-tuple of `(data, files)`.
-
- `data` will be an object which is the parsed content of the response.
- `files` will always be `None`.
+ Parses the incoming bytestream as YAML and returns the resulting data.
"""
assert yaml, 'YAMLParser requires pyyaml to be installed'
@@ -100,10 +94,8 @@ class FormParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None):
"""
- Returns a 2-tuple of `(data, files)`.
-
- `data` will be a :class:`QueryDict` containing all the form parameters.
- `files` will always be :const:`None`.
+ Parses the incoming bytestream as a URL encoded form,
+ and returns the resulting QueryDict.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
@@ -120,7 +112,8 @@ class MultiPartParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None):
"""
- Returns a DataAndFiles object.
+ Parses the incoming bytestream as a multipart encoded form,
+ and returns a DataAndFiles object.
`.data` will be a `QueryDict` containing all the form parameters.
`.files` will be a `QueryDict` containing all the form files.
@@ -147,6 +140,9 @@ class XMLParser(BaseParser):
media_type = 'application/xml'
def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Parses the incoming bytestream as XML and returns the resulting data.
+ """
assert etree, 'XMLParser requires defusedxml to be installed'
parser_context = parser_context or {}
@@ -216,7 +212,8 @@ class FileUploadParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None):
"""
- Returns a DataAndFiles object.
+ Treats the incoming bytestream as a raw file upload and returns
+ a `DateAndFiles` object.
`.data` will be None (we expect request body to be a file content).
`.files` will be a `QueryDict` containing one 'file' element.
diff --git a/awx/lib/site-packages/rest_framework/permissions.py b/awx/lib/site-packages/rest_framework/permissions.py
index 45fcfd6658..1036663e05 100644
--- a/awx/lib/site-packages/rest_framework/permissions.py
+++ b/awx/lib/site-packages/rest_framework/permissions.py
@@ -128,7 +128,7 @@ class DjangoModelPermissions(BasePermission):
# Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter.
- if model_cls is None and getattr(view, '_ignore_model_permissions'):
+ if model_cls is None and getattr(view, '_ignore_model_permissions', False):
return True
assert model_cls, ('Cannot apply DjangoModelPermissions on a view that'
diff --git a/awx/lib/site-packages/rest_framework/relations.py b/awx/lib/site-packages/rest_framework/relations.py
index e3675b5124..edaf76d6ee 100644
--- a/awx/lib/site-packages/rest_framework/relations.py
+++ b/awx/lib/site-packages/rest_framework/relations.py
@@ -12,7 +12,7 @@ from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component
+from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
@@ -144,7 +144,12 @@ class RelatedField(WritableField):
return None
if self.many:
- return [self.to_native(item) for item in value.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
@@ -242,7 +247,12 @@ class PrimaryKeyRelatedField(RelatedField):
queryset = get_component(queryset, component)
# Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
+ if is_simple_callable(getattr(queryset, 'all', None)):
+ return [self.to_native(item.pk) for item in queryset.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item.pk) for item in queryset]
# To-one relationship
try:
diff --git a/awx/lib/site-packages/rest_framework/renderers.py b/awx/lib/site-packages/rest_framework/renderers.py
index b2fe43eac2..3a03ca3324 100644
--- a/awx/lib/site-packages/rest_framework/renderers.py
+++ b/awx/lib/site-packages/rest_framework/renderers.py
@@ -11,14 +11,15 @@ from __future__ import unicode_literals
import copy
import json
from django import forms
+from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
+from django.test.client import encode_multipart
from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
-from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
@@ -270,7 +271,7 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name]
elif hasattr(view, 'get_template_names'):
return view.get_template_names()
- raise ConfigurationError('Returned a template response with no template_name')
+ raise ImproperlyConfigured('Returned a template response with no template_name')
def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code}
@@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer):
response.status_code = status.HTTP_200_OK
return ret
+
+
+class MultiPartRenderer(BaseRenderer):
+ media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
+ format = 'multipart'
+ charset = 'utf-8'
+ BOUNDARY = 'BoUnDaRyStRiNg'
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ return encode_multipart(self.BOUNDARY, data)
diff --git a/awx/lib/site-packages/rest_framework/request.py b/awx/lib/site-packages/rest_framework/request.py
index 0d88ebc7e4..919716f49a 100644
--- a/awx/lib/site-packages/rest_framework/request.py
+++ b/awx/lib/site-packages/rest_framework/request.py
@@ -64,6 +64,20 @@ def clone_request(request, method):
return ret
+class ForcedAuthentication(object):
+ """
+ This authentication class is used if the test client or request factory
+ forcibly authenticated the request.
+ """
+
+ def __init__(self, force_user, force_token):
+ self.force_user = force_user
+ self.force_token = force_token
+
+ def authenticate(self, request):
+ return (self.force_user, self.force_token)
+
+
class Request(object):
"""
Wrapper allowing to enhance a standard `HttpRequest` instance.
@@ -98,6 +112,12 @@ class Request(object):
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
+ force_user = getattr(request, '_force_auth_user', None)
+ force_token = getattr(request, '_force_auth_token', None)
+ if (force_user is not None or force_token is not None):
+ forced_auth = ForcedAuthentication(force_user, force_token)
+ self.authenticators = (forced_auth,)
+
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
diff --git a/awx/lib/site-packages/rest_framework/response.py b/awx/lib/site-packages/rest_framework/response.py
index 3ee52ae01f..5877c8a3e9 100644
--- a/awx/lib/site-packages/rest_framework/response.py
+++ b/awx/lib/site-packages/rest_framework/response.py
@@ -12,7 +12,7 @@ from rest_framework.compat import six
class Response(SimpleTemplateResponse):
"""
- An HttpResponse that allows it's data to be rendered into
+ An HttpResponse that allows its data to be rendered into
arbitrary media types.
"""
diff --git a/awx/lib/site-packages/rest_framework/routers.py b/awx/lib/site-packages/rest_framework/routers.py
index 6c5fd00483..930011d39e 100644
--- a/awx/lib/site-packages/rest_framework/routers.py
+++ b/awx/lib/site-packages/rest_framework/routers.py
@@ -15,10 +15,11 @@ For example, you might have a `urls.py` that looks something like this:
"""
from __future__ import unicode_literals
+import itertools
from collections import namedtuple
+from django.core.exceptions import ImproperlyConfigured
from rest_framework import views
from rest_framework.compat import patterns, url
-from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
@@ -39,6 +40,13 @@ def replace_methodname(format_string, methodname):
return ret
+def flatten(list_of_lists):
+ """
+ Takes an iterable of iterables, returns a single iterable containing all items
+ """
+ return itertools.chain(*list_of_lists)
+
+
class BaseRouter(object):
def __init__(self):
self.registry = []
@@ -72,7 +80,7 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
- url=r'^{prefix}/$',
+ url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
'post': 'create'
@@ -82,7 +90,7 @@ class SimpleRouter(BaseRouter):
),
# Detail route.
Route(
- url=r'^{prefix}/{lookup}/$',
+ url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
'put': 'update',
@@ -95,7 +103,7 @@ class SimpleRouter(BaseRouter):
# Dynamically generated routes.
# Generated using @action or @link decorators on methods of the viewset.
Route(
- url=r'^{prefix}/{lookup}/{methodname}/$',
+ url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
mapping={
'{httpmethod}': '{methodname}',
},
@@ -104,6 +112,10 @@ class SimpleRouter(BaseRouter):
),
]
+ def __init__(self, trailing_slash=True):
+ self.trailing_slash = trailing_slash and '/' or ''
+ super(SimpleRouter, self).__init__()
+
def get_default_base_name(self, viewset):
"""
If `base_name` is not specified, attempt to automatically determine
@@ -114,7 +126,7 @@ class SimpleRouter(BaseRouter):
if model_cls is None and queryset is not None:
model_cls = queryset.model
- assert model_cls, '`name` not argument not specified, and could ' \
+ assert model_cls, '`base_name` argument not specified, and could ' \
'not automatically determine the name from the viewset, as ' \
'it does not have a `.model` or `.queryset` attribute.'
@@ -127,12 +139,18 @@ class SimpleRouter(BaseRouter):
Returns a list of the Route namedtuple.
"""
+ known_actions = flatten([route.mapping.values() for route in self.routes])
+
# Determine any `@action` or `@link` decorated methods on the viewset
dynamic_routes = []
for methodname in dir(viewset):
attr = getattr(viewset, methodname)
httpmethods = getattr(attr, 'bind_to_methods', None)
if httpmethods:
+ if methodname in known_actions:
+ raise ImproperlyConfigured('Cannot use @action or @link decorator on '
+ 'method "%s" as it is an existing route' % methodname)
+ httpmethods = [method.lower() for method in httpmethods]
dynamic_routes.append((httpmethods, methodname))
ret = []
@@ -193,7 +211,11 @@ class SimpleRouter(BaseRouter):
continue
# Build the url pattern
- regex = route.url.format(prefix=prefix, lookup=lookup)
+ regex = route.url.format(
+ prefix=prefix,
+ lookup=lookup,
+ trailing_slash=self.trailing_slash
+ )
view = viewset.as_view(mapping, **route.initkwargs)
name = route.name.format(basename=basename)
ret.append(url(regex, view, name=name))
@@ -208,6 +230,7 @@ class DefaultRouter(SimpleRouter):
"""
include_root_view = True
include_format_suffixes = True
+ root_view_name = 'api-root'
def get_api_root_view(self):
"""
@@ -237,7 +260,7 @@ class DefaultRouter(SimpleRouter):
urls = []
if self.include_root_view:
- root_url = url(r'^$', self.get_api_root_view(), name='api-root')
+ root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls()
diff --git a/awx/lib/site-packages/rest_framework/runtests/settings.py b/awx/lib/site-packages/rest_framework/runtests/settings.py
index 9dd7b545e6..b3702d0bfa 100644
--- a/awx/lib/site-packages/rest_framework/runtests/settings.py
+++ b/awx/lib/site-packages/rest_framework/runtests/settings.py
@@ -134,6 +134,8 @@ PASSWORD_HASHERS = (
'django.contrib.auth.hashers.CryptPasswordHasher',
)
+AUTH_USER_MODEL = 'auth.User'
+
import django
if django.VERSION < (1, 3):
diff --git a/awx/lib/site-packages/rest_framework/serializers.py b/awx/lib/site-packages/rest_framework/serializers.py
index 11ead02e4f..31cfa34474 100644
--- a/awx/lib/site-packages/rest_framework/serializers.py
+++ b/awx/lib/site-packages/rest_framework/serializers.py
@@ -683,14 +683,14 @@ class ModelSerializer(Serializer):
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name not in self.base_fields.keys(), \
- "field '%s' on serializer '%s' specfied in " \
+ "field '%s' on serializer '%s' specified in " \
"`read_only_fields`, but also added " \
- "as an explict field. Remove it from `read_only_fields`." % \
+ "as an explicit field. Remove it from `read_only_fields`." % \
(field_name, self.__class__.__name__)
assert field_name in ret, \
- "Noexistant field '%s' specified in `read_only_fields` " \
+ "Non-existant field '%s' specified in `read_only_fields` " \
"on serializer '%s'." % \
- (self.__class__.__name__, field_name)
+ (field_name, self.__class__.__name__)
ret[field_name].read_only = True
return ret
@@ -904,34 +904,23 @@ class HyperlinkedModelSerializer(ModelSerializer):
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
- # Just a placeholder to ensure 'url' is the first field
- # The field itself is actually created on initialization,
- # when the view_name and lookup_field arguments are available.
- url = Field()
-
- def __init__(self, *args, **kwargs):
- super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
+ def get_default_fields(self):
+ fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- url_field = HyperlinkedIdentityField(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
- url_field.initialize(self, 'url')
- self.fields['url'] = url_field
+ if 'url' not in fields:
+ url_field = HyperlinkedIdentityField(
+ view_name=self.opts.view_name,
+ lookup_field=self.opts.lookup_field
+ )
+ ret = self._dict_class()
+ ret['url'] = url_field
+ ret.update(fields)
+ fields = ret
- def _get_default_view_name(self, model):
- """
- Return the view name to use if 'view_name' is not specified in 'Meta'
- """
- model_meta = model._meta
- format_kwargs = {
- 'app_label': model_meta.app_label,
- 'model_name': model_meta.object_name.lower()
- }
- return self._default_view_name % format_kwargs
+ return fields
def get_pk_field(self, model_field):
if self.opts.fields and model_field.name in self.opts.fields:
@@ -966,3 +955,14 @@ class HyperlinkedModelSerializer(ModelSerializer):
return data.get('url', None)
except AttributeError:
return None
+
+ def _get_default_view_name(self, model):
+ """
+ Return the view name to use if 'view_name' is not specified in 'Meta'
+ """
+ model_meta = model._meta
+ format_kwargs = {
+ 'app_label': model_meta.app_label,
+ 'model_name': model_meta.object_name.lower()
+ }
+ return self._default_view_name % format_kwargs
diff --git a/awx/lib/site-packages/rest_framework/settings.py b/awx/lib/site-packages/rest_framework/settings.py
index beb511aca2..8fd177d586 100644
--- a/awx/lib/site-packages/rest_framework/settings.py
+++ b/awx/lib/site-packages/rest_framework/settings.py
@@ -73,6 +73,13 @@ DEFAULTS = {
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # Testing
+ 'TEST_REQUEST_RENDERER_CLASSES': (
+ 'rest_framework.renderers.MultiPartRenderer',
+ 'rest_framework.renderers.JSONRenderer'
+ ),
+ 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
+
# Browser enhancements
'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content',
@@ -115,6 +122,7 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
'FILTER_BACKEND',
+ 'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
diff --git a/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html b/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html
index 9d939e738b..51f9c2916b 100644
--- a/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html
+++ b/awx/lib/site-packages/rest_framework/templates/rest_framework/base.html
@@ -196,7 +196,7 @@
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
{% endif %}
{% if raw_data_patch_form %}
- <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PUT request on the {{ name }} resource">PATCH</button>
+ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PATCH request on the {{ name }} resource">PATCH</button>
{% endif %}
</div>
</fieldset>
diff --git a/awx/lib/site-packages/rest_framework/test.py b/awx/lib/site-packages/rest_framework/test.py
new file mode 100644
index 0000000000..a18f5a2938
--- /dev/null
+++ b/awx/lib/site-packages/rest_framework/test.py
@@ -0,0 +1,157 @@
+# -- coding: utf-8 --
+
+# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
+# to make it harder for the user to import the wrong thing without realizing.
+from __future__ import unicode_literals
+import django
+from django.conf import settings
+from django.test.client import Client as DjangoClient
+from django.test.client import ClientHandler
+from django.test import testcases
+from rest_framework.settings import api_settings
+from rest_framework.compat import RequestFactory as DjangoRequestFactory
+from rest_framework.compat import force_bytes_or_smart_bytes, six
+
+
+def force_authenticate(request, user=None, token=None):
+ request._force_auth_user = user
+ request._force_auth_token = token
+
+
+class APIRequestFactory(DjangoRequestFactory):
+ renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
+ default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
+
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ self.renderer_classes = {}
+ for cls in self.renderer_classes_list:
+ self.renderer_classes[cls.format] = cls
+ super(APIRequestFactory, self).__init__(**defaults)
+
+ def _encode_data(self, data, format=None, content_type=None):
+ """
+ Encode the data returning a two tuple of (bytes, content_type)
+ """
+
+ if not data:
+ return ('', None)
+
+ assert format is None or content_type is None, (
+ 'You may not set both `format` and `content_type`.'
+ )
+
+ if content_type:
+ # Content type specified explicitly, treat data as a raw bytestring
+ ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
+
+ else:
+ format = format or self.default_format
+
+ assert format in self.renderer_classes, ("Invalid format '{0}'. "
+ "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
+ "to enable extra request formats.".format(
+ format,
+ ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
+ )
+ )
+
+ # Use format and render the data into a bytestring
+ renderer = self.renderer_classes[format]()
+ ret = renderer.render(data)
+
+ # Determine the content-type header from the renderer
+ content_type = "{0}; charset={1}".format(
+ renderer.media_type, renderer.charset
+ )
+
+ # Coerce text to bytes if required.
+ if isinstance(ret, six.text_type):
+ ret = bytes(ret.encode(renderer.charset))
+
+ return ret, content_type
+
+ def post(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('POST', path, data, content_type, **extra)
+
+ def put(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('PUT', path, data, content_type, **extra)
+
+ def patch(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('PATCH', path, data, content_type, **extra)
+
+ def delete(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('DELETE', path, data, content_type, **extra)
+
+ def options(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('OPTIONS', path, data, content_type, **extra)
+
+ def request(self, **kwargs):
+ request = super(APIRequestFactory, self).request(**kwargs)
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
+ return request
+
+
+class ForceAuthClientHandler(ClientHandler):
+ """
+ A patched version of ClientHandler that can enforce authentication
+ on the outgoing requests.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._force_user = None
+ self._force_token = None
+ super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
+
+ def get_response(self, request):
+ # This is the simplest place we can hook into to patch the
+ # request object.
+ force_authenticate(request, self._force_user, self._force_token)
+ return super(ForceAuthClientHandler, self).get_response(request)
+
+
+class APIClient(APIRequestFactory, DjangoClient):
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ super(APIClient, self).__init__(**defaults)
+ self.handler = ForceAuthClientHandler(enforce_csrf_checks)
+ self._credentials = {}
+
+ def credentials(self, **kwargs):
+ """
+ Sets headers that will be used on every outgoing request.
+ """
+ self._credentials = kwargs
+
+ def force_authenticate(self, user=None, token=None):
+ """
+ Forcibly authenticates outgoing requests with the given
+ user and/or token.
+ """
+ self.handler._force_user = user
+ self.handler._force_token = token
+
+ def request(self, **kwargs):
+ # Ensure that any credentials set get added to every request.
+ kwargs.update(self._credentials)
+ return super(APIClient, self).request(**kwargs)
+
+
+class APITransactionTestCase(testcases.TransactionTestCase):
+ client_class = APIClient
+
+
+class APITestCase(testcases.TestCase):
+ client_class = APIClient
+
+
+if django.VERSION >= (1, 4):
+ class APISimpleTestCase(testcases.SimpleTestCase):
+ client_class = APIClient
+
+ class APILiveServerTestCase(testcases.LiveServerTestCase):
+ client_class = APIClient
diff --git a/awx/lib/site-packages/rest_framework/tests/description.py b/awx/lib/site-packages/rest_framework/tests/description.py
new file mode 100644
index 0000000000..b46d7f54d8
--- /dev/null
+++ b/awx/lib/site-packages/rest_framework/tests/description.py
@@ -0,0 +1,26 @@
+# -- coding: utf-8 --
+
+# Apparently there is a python 2.6 issue where docstrings of imported view classes
+# do not retain their encoding information even if a module has a proper
+# encoding declaration at the top of its source file. Therefore for tests
+# to catch unicode related errors, a mock view has to be declared in a separate
+# module.
+
+from rest_framework.views import APIView
+
+
+# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
+# http://winrus.com/utf8-jap.htm and memory
+UTF8_TEST_DOCSTRING = (
+ 'zażółć gęślą jaźń'
+ 'Sîne klâwen durh die wolken sint geslagen'
+ 'Τη γλώσσα μου έδωσαν ελληνική'
+ 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
+ 'На берегу пустынных волн'
+ 'てすと'
+ 'アイウエオカキクケコサシスセソタチツテ'
+)
+
+
+class ViewWithNonASCIICharactersInDocstring(APIView):
+ __doc__ = UTF8_TEST_DOCSTRING
diff --git a/awx/lib/site-packages/rest_framework/tests/models.py b/awx/lib/site-packages/rest_framework/tests/models.py
index e2d4eacdc9..1598ecd94a 100644
--- a/awx/lib/site-packages/rest_framework/tests/models.py
+++ b/awx/lib/site-packages/rest_framework/tests/models.py
@@ -52,7 +52,7 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel):
- rel = models.ManyToManyField(Anchor)
+ rel = models.ManyToManyField(Anchor, help_text='Some help text.')
class ReadOnlyManyToManyModel(RESTFrameworkModel):
diff --git a/awx/lib/site-packages/rest_framework/tests/test_authentication.py b/awx/lib/site-packages/rest_framework/tests/test_authentication.py
index d46ac07985..a44813b691 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_authentication.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_authentication.py
@@ -1,7 +1,7 @@
from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.http import HttpResponse
-from django.test import Client, TestCase
+from django.test import TestCase
from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
@@ -21,14 +21,13 @@ from rest_framework.authtoken.models import Token
from rest_framework.compat import patterns, url, include
from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider
-from rest_framework.tests.utils import RequestFactory
+from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
-import json
import base64
import time
import datetime
-factory = RequestFactory()
+factory = APIRequestFactory()
class MockView(APIView):
@@ -68,7 +67,7 @@ class BasicAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication'
def setUp(self):
- self.csrf_client = Client(enforce_csrf_checks=True)
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
@@ -87,7 +86,7 @@ class BasicAuthTests(TestCase):
credentials = ('%s:%s' % (self.username, self.password))
base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials
- response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_basic_auth(self):
@@ -97,7 +96,7 @@ class BasicAuthTests(TestCase):
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
@@ -107,8 +106,8 @@ class SessionAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication'
def setUp(self):
- self.csrf_client = Client(enforce_csrf_checks=True)
- self.non_csrf_client = Client(enforce_csrf_checks=False)
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.non_csrf_client = APIClient(enforce_csrf_checks=False)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
@@ -154,7 +153,7 @@ class TokenAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication'
def setUp(self):
- self.csrf_client = Client(enforce_csrf_checks=True)
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
@@ -172,7 +171,7 @@ class TokenAuthTests(TestCase):
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_token_auth(self):
@@ -182,7 +181,7 @@ class TokenAuthTests(TestCase):
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self):
@@ -193,33 +192,33 @@ class TokenAuthTests(TestCase):
def test_token_login_json(self):
"""Ensure token login view using JSON POST works."""
- client = Client(enforce_csrf_checks=True)
+ client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/',
- json.dumps({'username': self.username, 'password': self.password}), 'application/json')
+ {'username': self.username, 'password': self.password}, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
+ self.assertEqual(response.data['token'], self.key)
def test_token_login_json_bad_creds(self):
"""Ensure token login view using JSON POST fails if bad credentials are used."""
- client = Client(enforce_csrf_checks=True)
+ client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/',
- json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
+ {'username': self.username, 'password': "badpass"}, format='json')
self.assertEqual(response.status_code, 400)
def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields."""
- client = Client(enforce_csrf_checks=True)
+ client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/',
- json.dumps({'username': self.username}), 'application/json')
+ {'username': self.username}, format='json')
self.assertEqual(response.status_code, 400)
def test_token_login_form(self):
"""Ensure token login view using form POST works."""
- client = Client(enforce_csrf_checks=True)
+ client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/',
{'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
+ self.assertEqual(response.data['token'], self.key)
class IncorrectCredentialsTests(TestCase):
@@ -256,7 +255,7 @@ class OAuthTests(TestCase):
self.consts = consts
- self.csrf_client = Client(enforce_csrf_checks=True)
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
@@ -428,13 +427,55 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth-with-scope/', params)
self.assertEqual(response.status_code, 200)
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_consumer_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': 'badconsumerkey'
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_token_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': 'badtokenkey',
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
urls = 'rest_framework.tests.test_authentication'
def setUp(self):
- self.csrf_client = Client(enforce_csrf_checks=True)
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
diff --git a/awx/lib/site-packages/rest_framework/tests/test_decorators.py b/awx/lib/site-packages/rest_framework/tests/test_decorators.py
index 1016fed3ff..195f0ba3e4 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_decorators.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_decorators.py
@@ -1,12 +1,13 @@
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import status
+from rest_framework.authentication import BasicAuthentication
+from rest_framework.parsers import JSONParser
+from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.renderers import JSONRenderer
-from rest_framework.parsers import JSONParser
-from rest_framework.authentication import BasicAuthentication
+from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle
-from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
from rest_framework.decorators import (
api_view,
@@ -17,13 +18,11 @@ from rest_framework.decorators import (
permission_classes,
)
-from rest_framework.tests.utils import RequestFactory
-
class DecoratorTestCase(TestCase):
def setUp(self):
- self.factory = RequestFactory()
+ self.factory = APIRequestFactory()
def _finalize_response(self, request, response, *args, **kwargs):
response.request = request
diff --git a/awx/lib/site-packages/rest_framework/tests/test_description.py b/awx/lib/site-packages/rest_framework/tests/test_description.py
index 52c1a34c10..8019f5ecaf 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_description.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_description.py
@@ -2,8 +2,10 @@
from __future__ import unicode_literals
from django.test import TestCase
+from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView
-from rest_framework.compat import apply_markdown
+from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
+from rest_framework.tests.description import UTF8_TEST_DOCSTRING
from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
@@ -83,11 +85,10 @@ class TestViewNamesAndDescriptions(TestCase):
Unicode in docstrings should be respected.
"""
- class MockView(APIView):
- """Проверка"""
- pass
-
- self.assertEqual(get_view_description(MockView), "Проверка")
+ self.assertEqual(
+ get_view_description(ViewWithNonASCIICharactersInDocstring),
+ smart_text(UTF8_TEST_DOCSTRING)
+ )
def test_view_description_can_be_empty(self):
"""
diff --git a/awx/lib/site-packages/rest_framework/tests/test_fields.py b/awx/lib/site-packages/rest_framework/tests/test_fields.py
index de3710011c..6836ec86f6 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_fields.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_fields.py
@@ -865,4 +865,34 @@ class FieldCallableDefault(TestCase):
field = serializers.WritableField(default=self.simple_callable)
into = {}
field.field_from_native({}, {}, 'field', into)
- self.assertEquals(into, {'field': 'foo bar'})
+ self.assertEqual(into, {'field': 'foo bar'})
+
+
+class CustomIntegerField(TestCase):
+ """
+ Test that custom fields apply min_value and max_value constraints
+ """
+ def test_custom_fields_can_be_validated_for_value(self):
+
+ class MoneyField(models.PositiveIntegerField):
+ pass
+
+ class EntryModel(models.Model):
+ bank = MoneyField(validators=[validators.MaxValueValidator(100)])
+
+ class EntrySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = EntryModel
+
+ entry = EntryModel(bank=1)
+
+ serializer = EntrySerializer(entry, data={"bank": 11})
+ self.assertTrue(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": -1})
+ self.assertFalse(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": 101})
+ self.assertFalse(serializer.is_valid())
+
+
diff --git a/awx/lib/site-packages/rest_framework/tests/test_filters.py b/awx/lib/site-packages/rest_framework/tests/test_filters.py
index aaed624782..c9d9e7ffaa 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_filters.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_filters.py
@@ -4,13 +4,13 @@ from decimal import Decimal
from django.db import models
from django.core.urlresolvers import reverse
from django.test import TestCase
-from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
+from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
-factory = RequestFactory()
+factory = APIRequestFactory()
class FilterableItem(models.Model):
diff --git a/awx/lib/site-packages/rest_framework/tests/test_generics.py b/awx/lib/site-packages/rest_framework/tests/test_generics.py
index 37734195aa..1550880b56 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_generics.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_generics.py
@@ -3,12 +3,11 @@ from django.db import models
from django.shortcuts import get_object_or_404
from django.test import TestCase
from rest_framework import generics, renderers, serializers, status
-from rest_framework.tests.utils import RequestFactory
+from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
-import json
-factory = RequestFactory()
+factory = APIRequestFactory()
class RootView(generics.ListCreateAPIView):
@@ -71,9 +70,8 @@ class TestRootView(TestCase):
"""
POST requests to ListCreateAPIView should create a new object.
"""
- content = {'text': 'foobar'}
- request = factory.post('/', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -85,9 +83,8 @@ class TestRootView(TestCase):
"""
PUT requests to ListCreateAPIView should not be allowed
"""
- content = {'text': 'foobar'}
- request = factory.put('/', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
@@ -148,9 +145,8 @@ class TestRootView(TestCase):
"""
POST requests to create a new object should not be able to set the id.
"""
- content = {'id': 999, 'text': 'foobar'}
- request = factory.post('/', json.dumps(content),
- content_type='application/json')
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -189,9 +185,8 @@ class TestInstanceView(TestCase):
"""
POST requests to RetrieveUpdateDestroyAPIView should not be allowed
"""
- content = {'text': 'foobar'}
- request = factory.post('/', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
@@ -201,9 +196,8 @@ class TestInstanceView(TestCase):
"""
PUT requests to RetrieveUpdateDestroyAPIView should update an object.
"""
- content = {'text': 'foobar'}
- request = factory.put('/1', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
with self.assertNumQueries(2):
response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -215,9 +209,8 @@ class TestInstanceView(TestCase):
"""
PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
"""
- content = {'text': 'foobar'}
- request = factory.patch('/1', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.patch('/1', data, format='json')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
@@ -293,9 +286,8 @@ class TestInstanceView(TestCase):
"""
PUT requests to create a new object should not be able to set the id.
"""
- content = {'id': 999, 'text': 'foobar'}
- request = factory.put('/1', json.dumps(content),
- content_type='application/json')
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -309,9 +301,8 @@ class TestInstanceView(TestCase):
if it does not currently exist.
"""
self.objects.get(id=1).delete()
- content = {'text': 'foobar'}
- request = factory.put('/1', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
with self.assertNumQueries(3):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -324,10 +315,9 @@ class TestInstanceView(TestCase):
PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if it doesn't exist.
"""
- content = {'text': 'foobar'}
+ data = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set the pk for a new object
- request = factory.put('/5', json.dumps(content),
- content_type='application/json')
+ request = factory.put('/5', data, format='json')
with self.assertNumQueries(3):
response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -339,9 +329,8 @@ class TestInstanceView(TestCase):
PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
"""
- content = {'text': 'foobar'}
- request = factory.put('/test_slug', json.dumps(content),
- content_type='application/json')
+ data = {'text': 'foobar'}
+ request = factory.put('/test_slug', data, format='json')
with self.assertNumQueries(2):
response = self.slug_based_view(request, slug='test_slug').render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase):
https://github.com/tomchristie/django-rest-framework/issues/285
"""
- content = {'email': 'foobar@example.com', 'content': 'foobar'}
- request = factory.post('/', json.dumps(content),
- content_type='application/json')
+ data = {'email': 'foobar@example.com', 'content': 'foobar'}
+ request = factory.post('/', data, format='json')
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1)
diff --git a/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py b/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py
index 1894ddb27e..61e613d75e 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_hyperlinkedserializers.py
@@ -1,12 +1,15 @@
from __future__ import unicode_literals
import json
from django.test import TestCase
-from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
from rest_framework.compat import patterns, url
-from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
+from rest_framework.test import APIRequestFactory
+from rest_framework.tests.models import (
+ Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
+ Album, Photo, OptionalRelationModel
+)
-factory = RequestFactory()
+factory = APIRequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer):
@@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):
class PhotoSerializer(serializers.Serializer):
description = serializers.CharField()
- album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title')
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
def restore_object(self, attrs, instance=None):
return Photo(**attrs)
@@ -301,3 +304,30 @@ class TestOptionalRelationHyperlinkedView(TestCase):
data=json.dumps(self.data),
content_type='application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class TestOverriddenURLField(TestCase):
+ def setUp(self):
+ class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.SerializerMethodField('get_url')
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'url')
+
+ def get_url(self, obj):
+ return 'foo bar'
+
+ self.Serializer = OverriddenURLSerializer
+ self.obj = BlogPost.objects.create(title='New blog post')
+
+ def test_overridden_url_field(self):
+ """
+ The 'url' field should respect overriding.
+ Regression test for #936.
+ """
+ serializer = self.Serializer(self.obj)
+ self.assertEqual(
+ serializer.data,
+ {'title': 'New blog post', 'url': 'foo bar'}
+ )
diff --git a/awx/lib/site-packages/rest_framework/tests/test_negotiation.py b/awx/lib/site-packages/rest_framework/tests/test_negotiation.py
index 7f84827f0e..04b89eb600 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_negotiation.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_negotiation.py
@@ -1,12 +1,12 @@
from __future__ import unicode_literals
from django.test import TestCase
-from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation
from rest_framework.request import Request
from rest_framework.renderers import BaseRenderer
+from rest_framework.test import APIRequestFactory
-factory = RequestFactory()
+factory = APIRequestFactory()
class MockJSONRenderer(BaseRenderer):
diff --git a/awx/lib/site-packages/rest_framework/tests/test_pagination.py b/awx/lib/site-packages/rest_framework/tests/test_pagination.py
index e538a78e5b..85d4640ead 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_pagination.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_pagination.py
@@ -4,13 +4,13 @@ from decimal import Decimal
from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
-from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
+from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
-factory = RequestFactory()
+factory = APIRequestFactory()
class FilterableItem(models.Model):
@@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase):
self.page = paginator.page(1)
def test_custom_pagination_serializer(self):
- request = RequestFactory().get('/foobar')
+ request = APIRequestFactory().get('/foobar')
serializer = CustomPaginationSerializer(
instance=self.page,
context={'request': request}
diff --git a/awx/lib/site-packages/rest_framework/tests/test_permissions.py b/awx/lib/site-packages/rest_framework/tests/test_permissions.py
index 6caaf65b02..e2cca3808c 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_permissions.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_permissions.py
@@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission
from django.db import models
from django.test import TestCase
from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
-from rest_framework.tests.utils import RequestFactory
+from rest_framework.test import APIRequestFactory
import base64
-import json
-factory = RequestFactory()
+factory = APIRequestFactory()
class BasicModel(models.Model):
@@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase):
BasicModel(text='foo').save()
def test_has_create_permissions(self):
- request = factory.post('/', json.dumps({'text': 'foobar'}),
- content_type='application/json',
+ request = factory.post('/', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_has_put_permissions(self):
- request = factory.put('/1', json.dumps({'text': 'foobar'}),
- content_type='application/json',
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_does_not_have_create_permissions(self):
- request = factory.post('/', json.dumps({'text': 'foobar'}),
- content_type='application/json',
+ request = factory.post('/', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_put_permissions(self):
- request = factory.put('/1', json.dumps({'text': 'foobar'}),
- content_type='application/json',
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase):
def test_has_put_as_create_permissions(self):
# User only has update permissions - should be able to update an entity.
- request = factory.put('/1', json.dumps({'text': 'foobar'}),
- content_type='application/json',
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
# But if PUTing to a new entity, permission should be denied.
- request = factory.put('/2', json.dumps({'text': 'foobar'}),
- content_type='application/json',
+ request = factory.put('/2', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self):
- request = factory.options('/', content_type='application/json',
+ request = factory.options('/',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST'])
- request = factory.options('/1', content_type='application/json',
+ request = factory.options('/1',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self):
- request = factory.options('/', content_type='application/json',
+ request = factory.options('/',
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
- request = factory.options('/1', content_type='application/json',
+ request = factory.options('/1',
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
def test_options_updateonly(self):
- request = factory.options('/', content_type='application/json',
+ request = factory.options('/',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
- request = factory.options('/1', content_type='application/json',
+ request = factory.options('/1',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py b/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py
index 2ca7f4f2b3..3c4d39af63 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_relations_hyperlink.py
@@ -1,15 +1,15 @@
from __future__ import unicode_literals
from django.test import TestCase
-from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
+from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import (
BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
)
-factory = RequestFactory()
+factory = APIRequestFactory()
request = factory.get('/') # Just to ensure we have a request in the serializer context
diff --git a/awx/lib/site-packages/rest_framework/tests/test_renderers.py b/awx/lib/site-packages/rest_framework/tests/test_renderers.py
index 95b597411f..df6f4aa63f 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_renderers.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_renderers.py
@@ -4,19 +4,17 @@ from __future__ import unicode_literals
from decimal import Decimal
from django.core.cache import cache
from django.test import TestCase
-from django.test.client import RequestFactory
from django.utils import unittest
from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions
-from rest_framework.compat import yaml, etree, patterns, url, include
+from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
-from rest_framework.compat import StringIO
-from rest_framework.compat import six
+from rest_framework.test import APIRequestFactory
import datetime
import pickle
import re
@@ -121,7 +119,7 @@ class POSTDeniedView(APIView):
class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view()
- request = RequestFactory().get('/')
+ request = APIRequestFactory().get('/')
response = view(request).render()
self.assertNotContains(response, '>POST<')
self.assertContains(response, '>PUT<')
diff --git a/awx/lib/site-packages/rest_framework/tests/test_request.py b/awx/lib/site-packages/rest_framework/tests/test_request.py
index a5c5e84ce7..969d8024a9 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_request.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_request.py
@@ -5,8 +5,7 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
-from django.test import TestCase, Client
-from django.test.client import RequestFactory
+from django.test import TestCase
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import patterns
@@ -19,12 +18,13 @@ from rest_framework.parsers import (
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
from rest_framework.compat import six
import json
-factory = RequestFactory()
+factory = APIRequestFactory()
class PlainTextParser(BaseParser):
@@ -116,16 +116,7 @@ class TestContentParsing(TestCase):
Ensure request.DATA returns content for PUT request with form content.
"""
data = {'qwerty': 'uiop'}
-
- from django import VERSION
-
- if VERSION >= (1, 5):
- from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart
- request = Request(factory.put('/', encode_multipart(BOUNDARY, data),
- content_type=MULTIPART_CONTENT))
- else:
- request = Request(factory.put('/', data))
-
+ request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(list(request.DATA.items()), list(data.items()))
@@ -257,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase):
urls = 'rest_framework.tests.test_request'
def setUp(self):
- self.csrf_client = Client(enforce_csrf_checks=True)
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
diff --git a/awx/lib/site-packages/rest_framework/tests/test_reverse.py b/awx/lib/site-packages/rest_framework/tests/test_reverse.py
index 93ef563776..690a30b119 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_reverse.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_reverse.py
@@ -1,10 +1,10 @@
from __future__ import unicode_literals
from django.test import TestCase
-from django.test.client import RequestFactory
from rest_framework.compat import patterns, url
from rest_framework.reverse import reverse
+from rest_framework.test import APIRequestFactory
-factory = RequestFactory()
+factory = APIRequestFactory()
def null_view(request):
diff --git a/awx/lib/site-packages/rest_framework/tests/test_routers.py b/awx/lib/site-packages/rest_framework/tests/test_routers.py
index 10d3cc25a0..5fcccb7414 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_routers.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_routers.py
@@ -1,14 +1,15 @@
from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
-from django.test.client import RequestFactory
-from rest_framework import serializers, viewsets
+from django.core.exceptions import ImproperlyConfigured
+from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action
from rest_framework.response import Response
-from rest_framework.routers import SimpleRouter
+from rest_framework.routers import SimpleRouter, DefaultRouter
+from rest_framework.test import APIRequestFactory
-factory = RequestFactory()
+factory = APIRequestFactory()
urlpatterns = patterns('',)
@@ -50,7 +51,7 @@ class TestSimpleRouter(TestCase):
route = decorator_routes[i]
# check url listing
self.assertEqual(route.url,
- '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint))
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
# check method to function mapping
if endpoint == 'action3':
methods_map = ['post', 'delete']
@@ -103,7 +104,7 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_list_view(self):
response = self.client.get('/notes/')
- self.assertEquals(response.data,
+ self.assertEqual(response.data,
[{
"url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar"
@@ -112,10 +113,104 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_detail_view(self):
response = self.client.get('/notes/123/')
- self.assertEquals(response.data,
+ self.assertEqual(response.data,
{
"url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar"
}
)
+
+class TestTrailingSlashIncluded(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlashRemoved(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestNameableRoot(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+ self.router = DefaultRouter()
+ self.router.root_view_name = 'nameable-root'
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_router_has_custom_name(self):
+ expected = 'nameable-root'
+ self.assertEqual(expected, self.urls[0].name)
+
+
+class TestActionKeywordArgs(TestCase):
+ """
+ Ensure keyword arguments passed in the `@action` decorator
+ are properly handled. Refs #940.
+ """
+
+ def setUp(self):
+ class TestViewSet(viewsets.ModelViewSet):
+ permission_classes = []
+
+ @action(permission_classes=[permissions.AllowAny])
+ def custom(self, request, *args, **kwargs):
+ return Response({
+ 'permission_classes': self.permission_classes
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+ self.view = self.router.urls[-1].callback
+
+ def test_action_kwargs(self):
+ request = factory.post('/test/0/custom/')
+ response = self.view(request)
+ self.assertEqual(
+ response.data,
+ {'permission_classes': [permissions.AllowAny]}
+ )
+
+
+class TestActionAppliedToExistingRoute(TestCase):
+ """
+ Ensure `@action` decorator raises an except when applied
+ to an existing route
+ """
+
+ def test_exception_raised_when_action_applied_to_existing_route(self):
+ class TestViewSet(viewsets.ModelViewSet):
+
+ @action()
+ def retrieve(self, request, *args, **kwargs):
+ return Response({
+ 'hello': 'world'
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+
+ with self.assertRaises(ImproperlyConfigured):
+ self.router.urls
diff --git a/awx/lib/site-packages/rest_framework/tests/test_serializer.py b/awx/lib/site-packages/rest_framework/tests/test_serializer.py
index 6cc913c5cd..c24976603e 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_serializer.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_serializer.py
@@ -494,7 +494,7 @@ class CustomValidationTests(TestCase):
}
serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']})
+ self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
class PositiveIntegerAsChoiceTests(TestCase):
@@ -1376,6 +1376,18 @@ class FieldLabelTest(TestCase):
self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
+# Test for issue #961
+
+class ManyFieldHelpTextTest(TestCase):
+ def test_help_text_no_hold_down_control_msg(self):
+ """
+ Validate that help_text doesn't contain the 'Hold down "Control" ...'
+ message that Django appends to choice fields.
+ """
+ rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text)
+ self.assertEqual('Some help text.', rel_field.help_text)
+
+
class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def setUp(self):
@@ -1556,3 +1568,78 @@ class MetadataSerializerTestCase(TestCase):
}
}
self.assertEqual(expected, metadata)
+
+
+### Regression test for #840
+
+class SimpleModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimpleModelSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ other = serializers.CharField()
+
+ class Meta:
+ model = SimpleModel
+
+ def validate_other(self, attrs, source):
+ del attrs['other']
+ return attrs
+
+
+class FieldValidationRemovingAttr(TestCase):
+ def test_removing_non_model_field_in_validation(self):
+ """
+ Removing an attr during field valiation should ensure that it is not
+ passed through when restoring the object.
+
+ This allows additional non-model fields to be supported.
+
+ Regression test for #840.
+ """
+ serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.text, 'foo')
+
+
+### Regression test for #878
+
+class SimpleTargetModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimplePKSourceModelSerializer(serializers.Serializer):
+ targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
+ text = serializers.CharField()
+
+
+class SimpleSlugSourceModelSerializer(serializers.Serializer):
+ targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
+ text = serializers.CharField()
+
+
+class SerializerSupportsManyRelationships(TestCase):
+ def setUp(self):
+ SimpleTargetModel.objects.create(text='foo')
+ SimpleTargetModel.objects.create(text='bar')
+
+ def test_serializer_supports_pk_many_relationships(self):
+ """
+ Regression test for #878.
+
+ Note that pk behavior has a different code path to usual cases,
+ for performance reasons.
+ """
+ serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+ def test_serializer_supports_slug_many_relationships(self):
+ """
+ Regression test for #878.
+ """
+ serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
diff --git a/awx/lib/site-packages/rest_framework/tests/test_testing.py b/awx/lib/site-packages/rest_framework/tests/test_testing.py
new file mode 100644
index 0000000000..49d45fc292
--- /dev/null
+++ b/awx/lib/site-packages/rest_framework/tests/test_testing.py
@@ -0,0 +1,115 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from django.contrib.auth.models import User
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
+
+
+@api_view(['GET', 'POST'])
+def view(request):
+ return Response({
+ 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
+ 'user': request.user.username
+ })
+
+
+urlpatterns = patterns('',
+ url(r'^view/$', view),
+)
+
+
+class TestAPITestClient(TestCase):
+ urls = 'rest_framework.tests.test_testing'
+
+ def setUp(self):
+ self.client = APIClient()
+
+ def test_credentials(self):
+ """
+ Setting `.credentials()` adds the required headers to each request.
+ """
+ self.client.credentials(HTTP_AUTHORIZATION='example')
+ for _ in range(0, 3):
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')
+
+ def test_force_authenticate(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ User.objects.create_user('example', 'example@example.com', 'password')
+ self.client.login(username='example', password='password')
+ response = self.client.post('/view/')
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ client = APIClient(enforce_csrf_checks=True)
+ User.objects.create_user('example', 'example@example.com', 'password')
+ client.login(username='example', password='password')
+ response = client.post('/view/')
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+
+class TestAPIRequestFactory(TestCase):
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory()
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory(enforce_csrf_checks=True)
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+ def test_invalid_format(self):
+ """
+ Attempting to use a format that is not configured will raise an
+ assertion error.
+ """
+ factory = APIRequestFactory()
+ self.assertRaises(AssertionError, factory.post,
+ path='/view/', data={'example': 1}, format='xml'
+ )
+
+ def test_force_authenticate(self):
+ """
+ Setting `force_authenticate()` forcibly authenticates the request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ factory = APIRequestFactory()
+ request = factory.get('/view')
+ force_authenticate(request, user=user)
+ response = view(request)
+ self.assertEqual(response.data['user'], 'example')
diff --git a/awx/lib/site-packages/rest_framework/tests/test_throttling.py b/awx/lib/site-packages/rest_framework/tests/test_throttling.py
index da400b2fcd..41bff6926a 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_throttling.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_throttling.py
@@ -5,9 +5,9 @@ from __future__ import unicode_literals
from django.test import TestCase
from django.contrib.auth.models import User
from django.core.cache import cache
-from django.test.client import RequestFactory
+from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
-from rest_framework.throttling import UserRateThrottle
+from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
from rest_framework.response import Response
@@ -21,6 +21,14 @@ class User3MinRateThrottle(UserRateThrottle):
scope = 'minutes'
+class NonTimeThrottle(BaseThrottle):
+ def allow_request(self, request, view):
+ if not hasattr(self.__class__, 'called'):
+ self.__class__.called = True
+ return True
+ return False
+
+
class MockView(APIView):
throttle_classes = (User3SecRateThrottle,)
@@ -35,15 +43,20 @@ class MockView_MinuteThrottling(APIView):
return Response('foo')
-class ThrottlingTests(TestCase):
- urls = 'rest_framework.tests.test_throttling'
+class MockView_NonTimeThrottling(APIView):
+ throttle_classes = (NonTimeThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+class ThrottlingTests(TestCase):
def setUp(self):
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
- self.factory = RequestFactory()
+ self.factory = APIRequestFactory()
def test_requests_are_throttled(self):
"""
@@ -141,3 +154,124 @@ class ThrottlingTests(TestCase):
(60, None),
(80, None)
))
+
+ def test_non_time_throttle(self):
+ """
+ Ensure for second based throttles.
+ """
+ request = self.factory.get('/')
+
+ self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+ self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+
+class ScopedRateThrottleTests(TestCase):
+ """
+ Tests for ScopedRateThrottle.
+ """
+
+ def setUp(self):
+ class XYScopedRateThrottle(ScopedRateThrottle):
+ TIMER_SECONDS = 0
+ THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+ timer = lambda self: self.TIMER_SECONDS
+
+ class XView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'x'
+
+ def get(self, request):
+ return Response('x')
+
+ class YView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'y'
+
+ def get(self, request):
+ return Response('y')
+
+ class UnscopedView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+
+ def get(self, request):
+ return Response('y')
+
+ self.throttle_class = XYScopedRateThrottle
+ self.factory = APIRequestFactory()
+ self.x_view = XView.as_view()
+ self.y_view = YView.as_view()
+ self.unscoped_view = UnscopedView.as_view()
+
+ def increment_timer(self, seconds=1):
+ self.throttle_class.TIMER_SECONDS += seconds
+
+ def test_scoped_rate_throttle(self):
+ request = self.factory.get('/')
+
+ # Should be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Ensure throttles properly reset by advancing the rest of the minute
+ self.increment_timer(55)
+
+ # Should still be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should still be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ def test_unscoped_view_not_throttled(self):
+ request = self.factory.get('/')
+
+ for idx in range(10):
+ self.increment_timer()
+ response = self.unscoped_view(request)
+ self.assertEqual(200, response.status_code)
diff --git a/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py b/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py
index 29ed4a961c..8132ec4c8e 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_urlpatterns.py
@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from collections import namedtuple
from django.core import urlresolvers
from django.test import TestCase
-from django.test.client import RequestFactory
+from rest_framework.test import APIRequestFactory
from rest_framework.compat import patterns, url, include
from rest_framework.urlpatterns import format_suffix_patterns
@@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase):
Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
"""
def _resolve_urlpatterns(self, urlpatterns, test_paths):
- factory = RequestFactory()
+ factory = APIRequestFactory()
try:
urlpatterns = format_suffix_patterns(urlpatterns)
except Exception:
diff --git a/awx/lib/site-packages/rest_framework/tests/test_validation.py b/awx/lib/site-packages/rest_framework/tests/test_validation.py
index a6ec0e993d..ebfdff9cd1 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_validation.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_validation.py
@@ -2,10 +2,9 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import generics, serializers, status
-from rest_framework.tests.utils import RequestFactory
-import json
+from rest_framework.test import APIRequestFactory
-factory = RequestFactory()
+factory = APIRequestFactory()
# Regression for #666
@@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase):
validation on read only fields.
"""
obj = ValidationModel.objects.create(blank_validated_field='')
- request = factory.put('/', json.dumps({}),
- content_type='application/json')
+ request = factory.put('/', {}, format='json')
view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/awx/lib/site-packages/rest_framework/tests/test_views.py b/awx/lib/site-packages/rest_framework/tests/test_views.py
index 2767d24c80..c0bec5aed1 100644
--- a/awx/lib/site-packages/rest_framework/tests/test_views.py
+++ b/awx/lib/site-packages/rest_framework/tests/test_views.py
@@ -1,17 +1,15 @@
from __future__ import unicode_literals
import copy
-
from django.test import TestCase
-from django.test.client import RequestFactory
-
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
-factory = RequestFactory()
+factory = APIRequestFactory()
class BasicView(APIView):
diff --git a/awx/lib/site-packages/rest_framework/tests/utils.py b/awx/lib/site-packages/rest_framework/tests/utils.py
deleted file mode 100644
index 8c87917d92..0000000000
--- a/awx/lib/site-packages/rest_framework/tests/utils.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from __future__ import unicode_literals
-from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory
-from django.test.client import MULTIPART_CONTENT
-from rest_framework.compat import urlparse
-
-
-class RequestFactory(_RequestFactory):
-
- def __init__(self, **defaults):
- super(RequestFactory, self).__init__(**defaults)
-
- def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
- **extra):
- "Construct a PATCH request."
-
- patch_data = self._encode_data(data, content_type)
-
- parsed = urlparse.urlparse(path)
- r = {
- 'CONTENT_LENGTH': len(patch_data),
- 'CONTENT_TYPE': content_type,
- 'PATH_INFO': self._get_path(parsed),
- 'QUERY_STRING': parsed[4],
- 'REQUEST_METHOD': 'PATCH',
- 'wsgi.input': FakePayload(patch_data),
- }
- r.update(extra)
- return self.request(**r)
-
-
-class Client(_Client, RequestFactory):
- def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
- follow=False, **extra):
- """
- Send a resource to the server using PATCH.
- """
- response = super(Client, self).patch(path, data=data, content_type=content_type, **extra)
- if follow:
- response = self._handle_redirects(response, **extra)
- return response
diff --git a/awx/lib/site-packages/rest_framework/throttling.py b/awx/lib/site-packages/rest_framework/throttling.py
index 93ea9816cb..65b4559307 100644
--- a/awx/lib/site-packages/rest_framework/throttling.py
+++ b/awx/lib/site-packages/rest_framework/throttling.py
@@ -3,7 +3,7 @@ Provides various throttling policies.
"""
from __future__ import unicode_literals
from django.core.cache import cache
-from rest_framework import exceptions
+from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
"""
timer = time.time
- settings = api_settings
cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None
+ THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
@@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
try:
- return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
+ return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
"""
@@ -96,6 +96,9 @@ class SimpleRateThrottle(BaseThrottle):
return True
self.key = self.get_cache_key(request, view)
+ if self.key is None:
+ return True
+
self.history = cache.get(self.key, [])
self.now = self.timer()
@@ -187,6 +190,27 @@ class ScopedRateThrottle(SimpleRateThrottle):
"""
scope_attr = 'throttle_scope'
+ def __init__(self):
+ # Override the usual SimpleRateThrottle, because we can't determine
+ # the rate until called by the view.
+ pass
+
+ def allow_request(self, request, view):
+ # We can only determine the scope once we're called by the view.
+ self.scope = getattr(view, self.scope_attr, None)
+
+ # If a view does not have a `throttle_scope` always allow the request
+ if not self.scope:
+ return True
+
+ # Determine the allowed request rate as we normally would during
+ # the `__init__` call.
+ self.rate = self.get_rate()
+ self.num_requests, self.duration = self.parse_rate(self.rate)
+
+ # We can now proceed as normal.
+ return super(ScopedRateThrottle, self).allow_request(request, view)
+
def get_cache_key(self, request, view):
"""
If `view.throttle_scope` is not set, don't apply this throttle.
@@ -194,18 +218,12 @@ class ScopedRateThrottle(SimpleRateThrottle):
Otherwise generate the unique cache key by concatenating the user id
with the '.throttle_scope` property of the view.
"""
- scope = getattr(view, self.scope_attr, None)
-
- if not scope:
- # Only throttle views if `.throttle_scope` is set on the view.
- return None
-
if request.user.is_authenticated():
ident = request.user.id
else:
ident = request.META.get('REMOTE_ADDR', None)
return self.cache_format % {
- 'scope': scope,
+ 'scope': self.scope,
'ident': ident
}
diff --git a/awx/lib/site-packages/rest_framework/utils/formatting.py b/awx/lib/site-packages/rest_framework/utils/formatting.py
index ebadb3a670..4bec838776 100644
--- a/awx/lib/site-packages/rest_framework/utils/formatting.py
+++ b/awx/lib/site-packages/rest_framework/utils/formatting.py
@@ -5,7 +5,7 @@ from __future__ import unicode_literals
from django.utils.html import escape
from django.utils.safestring import mark_safe
-from rest_framework.compat import apply_markdown
+from rest_framework.compat import apply_markdown, smart_text
import re
@@ -63,7 +63,7 @@ def get_view_description(cls, html=False):
Return a description for an `APIView` class or `@api_view` function.
"""
description = cls.__doc__ or ''
- description = _remove_leading_indent(description)
+ description = _remove_leading_indent(smart_text(description))
if html:
return markup_description(description)
return description
diff --git a/awx/lib/site-packages/rest_framework/views.py b/awx/lib/site-packages/rest_framework/views.py
index e1b6705b6d..d51233a932 100644
--- a/awx/lib/site-packages/rest_framework/views.py
+++ b/awx/lib/site-packages/rest_framework/views.py
@@ -4,11 +4,11 @@ Provides an APIView class that is the base of all views in REST framework.
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
-from django.http import Http404, HttpResponse
+from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View
+from rest_framework.compat import View, HttpResponseBase
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
@@ -244,9 +244,10 @@ class APIView(View):
Returns the final response object.
"""
# Make the error obvious if a proper response is not returned
- assert isinstance(response, HttpResponse), (
- 'Expected a `Response` to be returned from the view, '
- 'but received a `%s`' % type(response)
+ assert isinstance(response, HttpResponseBase), (
+ 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
+ 'to be returned from the view, but received a `%s`'
+ % type(response)
)
if isinstance(response, Response):
@@ -268,7 +269,7 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, exceptions.Throttled):
+ if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
@@ -304,10 +305,10 @@ class APIView(View):
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
"""
- request = self.initialize_request(request, *args, **kwargs)
- self.request = request
self.args = args
self.kwargs = kwargs
+ request = self.initialize_request(request, *args, **kwargs)
+ self.request = request
self.headers = self.default_response_headers # deprecate?
try:
@@ -341,8 +342,15 @@ class APIView(View):
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
+
+ # This is used by ViewSets to disambiguate instance vs list views
+ view_name_suffix = getattr(self, 'suffix', None)
+
+ # By default we can't provide any form-like information, however the
+ # generic views override this implementation and add additional
+ # information for POST and PUT methods, based on the serializer.
ret = SortedDict()
- ret['name'] = get_view_name(self.__class__)
+ ret['name'] = get_view_name(self.__class__, view_name_suffix)
ret['description'] = get_view_description(self.__class__)
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes]
diff --git a/awx/lib/site-packages/south/__init__.py b/awx/lib/site-packages/south/__init__.py
index 20c39178e6..2921dad23d 100644
--- a/awx/lib/site-packages/south/__init__.py
+++ b/awx/lib/site-packages/south/__init__.py
@@ -2,7 +2,7 @@
South - Useable migrations for Django apps
"""
-__version__ = "0.8.1"
+__version__ = "0.8.2"
__authors__ = [
"Andrew Godwin <andrew@aeracode.org>",
"Andy McCurdy <andy@andymccurdy.com>"
diff --git a/awx/lib/site-packages/south/creator/actions.py b/awx/lib/site-packages/south/creator/actions.py
index 37586c23ca..2ffc8ca19f 100644
--- a/awx/lib/site-packages/south/creator/actions.py
+++ b/awx/lib/site-packages/south/creator/actions.py
@@ -137,12 +137,14 @@ class _NullIssuesField(object):
A field that might need to ask a question about rogue NULL values.
"""
- allow_third_null_option = False
+ issue_with_backward_migration = False
irreversible = False
IRREVERSIBLE_TEMPLATE = '''
# User chose to not deal with backwards NULL issues for '%(model_name)s.%(field_name)s'
- raise RuntimeError("Cannot reverse this migration. '%(model_name)s.%(field_name)s' and its values cannot be restored.")'''
+ raise RuntimeError("Cannot reverse this migration. '%(model_name)s.%(field_name)s' and its values cannot be restored.")
+
+ # The following code is provided here to aid in writing a correct migration'''
def deal_with_not_null_no_default(self, field, field_def):
# If it's a CharField or TextField that's blank, skip this step.
@@ -156,17 +158,17 @@ class _NullIssuesField(object):
))
print(" ? Since you are %s, you MUST specify a default" % self.null_reason)
print(" ? value to use for existing rows. Would you like to:")
- print(" ? 1. Quit now, and add a default to the field in models.py")
+ print(" ? 1. Quit now"+("." if self.issue_with_backward_migration else ", and add a default to the field in models.py" ))
print(" ? 2. Specify a one-off value to use for existing columns now")
- if self.allow_third_null_option:
- print(" ? 3. Disable the backwards migration by raising an exception.")
+ if self.issue_with_backward_migration:
+ print(" ? 3. Disable the backwards migration by raising an exception; you can edit the migration to fix it later")
while True:
choice = raw_input(" ? Please select a choice: ")
if choice == "1":
sys.exit(1)
elif choice == "2":
break
- elif choice == "3" and self.allow_third_null_option:
+ elif choice == "3" and self.issue_with_backward_migration:
break
else:
print(" ! Invalid choice.")
@@ -266,7 +268,7 @@ class DeleteField(AddField):
"""
null_reason = "removing this field"
- allow_third_null_option = True
+ issue_with_backward_migration = True
def console_line(self):
"Returns the string to print on the console, e.g. ' + Added field foo'"
@@ -283,7 +285,7 @@ class DeleteField(AddField):
if not self.irreversible:
return AddField.forwards_code(self)
else:
- return self.irreversable_code(self.field)
+ return self.irreversable_code(self.field) + AddField.forwards_code(self)
class ChangeField(Action, _NullIssuesField):
@@ -315,7 +317,7 @@ class ChangeField(Action, _NullIssuesField):
self.deal_with_not_null_no_default(self.new_field, self.new_def)
if not self.old_field.null and self.new_field.null and not old_default:
self.null_reason = "making this field nullable"
- self.allow_third_null_option = True
+ self.issue_with_backward_migration = True
self.deal_with_not_null_no_default(self.old_field, self.old_def)
def console_line(self):
@@ -353,10 +355,11 @@ class ChangeField(Action, _NullIssuesField):
return self._code(self.old_field, self.new_field, self.new_def)
def backwards_code(self):
+ change_code = self._code(self.new_field, self.old_field, self.old_def)
if not self.irreversible:
- return self._code(self.new_field, self.old_field, self.old_def)
+ return change_code
else:
- return self.irreversable_code(self.old_field)
+ return self.irreversable_code(self.old_field) + change_code
class AddUnique(Action):
diff --git a/awx/lib/site-packages/south/creator/changes.py b/awx/lib/site-packages/south/creator/changes.py
index 9a11d9eb3b..6cdbd19de0 100644
--- a/awx/lib/site-packages/south/creator/changes.py
+++ b/awx/lib/site-packages/south/creator/changes.py
@@ -309,21 +309,21 @@ class AutoChanges(BaseChanges):
old_together = [old_together]
if new_together and isinstance(new_together[0], string_types):
new_together = [new_together]
- old_together = list(map(set, old_together))
- new_together = list(map(set, new_together))
+ old_together = frozenset(tuple(o) for o in old_together)
+ new_together = frozenset(tuple(n) for n in new_together)
# See if any appeared or disappeared
- for item in old_together:
- if item not in new_together:
- yield (del_operation, {
- "model": self.old_orm[key],
- "fields": [self.old_orm[key + ":" + x] for x in item],
- })
- for item in new_together:
- if item not in old_together:
- yield (add_operation, {
- "model": self.current_model_from_key(key),
- "fields": [self.current_field_from_key(key, x) for x in item],
- })
+ disappeared = old_together.difference(new_together)
+ appeared = new_together.difference(old_together)
+ for item in disappeared:
+ yield (del_operation, {
+ "model": self.old_orm[key],
+ "fields": [self.old_orm[key + ":" + x] for x in item],
+ })
+ for item in appeared:
+ yield (add_operation, {
+ "model": self.current_model_from_key(key),
+ "fields": [self.current_field_from_key(key, x) for x in item],
+ })
@classmethod
def is_triple(cls, triple):
diff --git a/awx/lib/site-packages/south/db/__init__.py b/awx/lib/site-packages/south/db/__init__.py
index 9927c27f0c..b9b7168e62 100644
--- a/awx/lib/site-packages/south/db/__init__.py
+++ b/awx/lib/site-packages/south/db/__init__.py
@@ -12,7 +12,8 @@ engine_modules = {
'django.db.backends.mysql': 'mysql',
'mysql_oursql.standard': 'mysql',
'django.db.backends.oracle': 'oracle',
- 'sql_server.pyodbc': 'sql_server.pyodbc', #django-pyodbc
+ 'sql_server.pyodbc': 'sql_server.pyodbc', #django-pyodbc-azure
+ 'django_pyodbc': 'sql_server.pyodbc', #django-pyodbc
'sqlserver_ado': 'sql_server.pyodbc', #django-mssql
'firebird': 'firebird', #django-firebird
'django.contrib.gis.db.backends.postgis': 'postgresql_psycopg2',
diff --git a/awx/lib/site-packages/south/db/firebird.py b/awx/lib/site-packages/south/db/firebird.py
index c55a82517f..6216a52cd5 100644
--- a/awx/lib/site-packages/south/db/firebird.py
+++ b/awx/lib/site-packages/south/db/firebird.py
@@ -71,11 +71,11 @@ class DatabaseOperations(generic.DatabaseOperations):
def _alter_set_defaults(self, field, name, params, sqls):
"Subcommand of alter_column that sets default values (overrideable)"
- # Next, set any default
- if not field.null and field.has_default():
- default = field.get_default()
- sqls.append(('ALTER COLUMN %s SET DEFAULT %%s ' % (self.quote_name(name),), [default]))
- elif self._column_has_default(params):
+ # Historically, we used to set defaults here.
+ # But since South 0.8, we don't ever set defaults on alter-column -- we only
+ # use database-level defaults as scaffolding when adding columns.
+ # However, we still sometimes need to remove defaults in alter-column.
+ if self._column_has_default(params):
sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), []))
diff --git a/awx/lib/site-packages/south/db/generic.py b/awx/lib/site-packages/south/db/generic.py
index 1a26d955b5..5c1935444d 100644
--- a/awx/lib/site-packages/south/db/generic.py
+++ b/awx/lib/site-packages/south/db/generic.py
@@ -444,12 +444,11 @@ class DatabaseOperations(object):
def _alter_set_defaults(self, field, name, params, sqls):
"Subcommand of alter_column that sets default values (overrideable)"
- # Next, set any default
- if not field.null and field.has_default():
- default = field.get_db_prep_save(field.get_default(), connection=self._get_connection())
- sqls.append(('ALTER COLUMN %s SET DEFAULT %%s ' % (self.quote_name(name),), [default]))
- else:
- sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), []))
+ # Historically, we used to set defaults here.
+ # But since South 0.8, we don't ever set defaults on alter-column -- we only
+ # use database-level defaults as scaffolding when adding columns.
+ # However, we still sometimes need to remove defaults in alter-column.
+ sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), []))
def _update_nulls_to_default(self, params, field):
"Subcommand of alter_column that updates nulls to default value (overrideable)"
@@ -835,8 +834,13 @@ class DatabaseOperations(object):
# If there is just one column in the index, use a default algorithm from Django
if len(column_names) == 1 and not suffix:
+ try:
+ _hash = self._digest([column_names[0]])
+ except TypeError:
+ # Django < 1.5 backward compatibility.
+ _hash = self._digest(column_names[0])
return self.shorten_name(
- '%s_%s' % (table_name, self._digest(column_names[0]))
+ '%s_%s' % (table_name, _hash),
)
# Else generate the name for the index by South
diff --git a/awx/lib/site-packages/south/db/oracle.py b/awx/lib/site-packages/south/db/oracle.py
index 6e002945ff..cb4148b492 100644
--- a/awx/lib/site-packages/south/db/oracle.py
+++ b/awx/lib/site-packages/south/db/oracle.py
@@ -86,9 +86,12 @@ class DatabaseOperations(generic.DatabaseOperations):
for field_name, field in fields:
+ field = self._field_sanity(field)
+
# avoid default values in CREATE TABLE statements (#925)
field._suppress_default = True
+
col = self.column_sql(table_name, field_name, field)
if not col:
continue
@@ -159,12 +162,12 @@ END;
'nullity': 'NOT NULL',
'default': 'NULL'
}
- if field.null:
+ if field.null:
params['nullity'] = 'NULL'
sql_templates = [
- (self.alter_string_set_type, params),
- (self.alter_string_set_default, params),
+ (self.alter_string_set_type, params, []),
+ (self.alter_string_set_default, params, []),
]
if not field.null and field.has_default():
# Use default for rows that had nulls. To support the case where
@@ -177,8 +180,8 @@ END;
p.update(kw)
return p
sql_templates[:0] = [
- (self.alter_string_set_type, change_params(nullity='NULL')),
- (self.alter_string_update_nulls_to_default, change_params(default=self._default_value_workaround(field.get_default()))),
+ (self.alter_string_set_type, change_params(nullity='NULL'),[]),
+ (self.alter_string_update_nulls_to_default, change_params(default="%s"), [field.get_default()]),
]
@@ -191,9 +194,9 @@ END;
'constraint': self.quote_name(constraint),
})
- for sql_template, params in sql_templates:
+ for sql_template, params, args in sql_templates:
try:
- self.execute(sql_template % params, print_all_errors=False)
+ self.execute(sql_template % params, args, print_all_errors=False)
except DatabaseError as exc:
description = str(exc)
# Oracle complains if a column is already NULL/NOT NULL
@@ -250,6 +253,7 @@ END;
@generic.invalidate_table_constraints
def add_column(self, table_name, name, field, keep_default=False):
+ field = self._field_sanity(field)
sql = self.column_sql(table_name, name, field)
sql = self.adj_column_sql(sql)
@@ -288,7 +292,11 @@ END;
"""
if isinstance(field, models.BooleanField) and field.has_default():
field.default = int(field.to_python(field.get_default()))
+ # On Oracle, empty strings are null
+ if isinstance(field, (models.CharField, models.TextField)):
+ field.null = field.empty_strings_allowed
return field
+
def _default_value_workaround(self, value):
from datetime import date,time,datetime
diff --git a/awx/lib/site-packages/south/db/sql_server/pyodbc.py b/awx/lib/site-packages/south/db/sql_server/pyodbc.py
index 1b200ad13b..b725ec0da6 100644
--- a/awx/lib/site-packages/south/db/sql_server/pyodbc.py
+++ b/awx/lib/site-packages/south/db/sql_server/pyodbc.py
@@ -242,19 +242,15 @@ class DatabaseOperations(generic.DatabaseOperations):
def _alter_set_defaults(self, field, name, params, sqls):
"Subcommand of alter_column that sets default values (overrideable)"
- # First drop the current default if one exists
+ # Historically, we used to set defaults here.
+ # But since South 0.8, we don't ever set defaults on alter-column -- we only
+ # use database-level defaults as scaffolding when adding columns.
+ # However, we still sometimes need to remove defaults in alter-column.
table_name = self.quote_name(params['table_name'])
drop_default = self.drop_column_default_sql(table_name, name)
if drop_default:
sqls.append((drop_default, []))
- # Next, set any default
-
- if field.has_default():
- default = field.get_default()
- literal = self._value_to_unquoted_literal(field, default)
- sqls.append(('ADD DEFAULT %s for %s' % (self._quote_string(literal), self.quote_name(name),), []))
-
def _value_to_unquoted_literal(self, field, value):
# Start with the field's own translation
conn = self._get_connection()
@@ -432,6 +428,7 @@ class DatabaseOperations(generic.DatabaseOperations):
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
INNER JOIN sys.indexes i ON i.object_id = t.object_id
INNER JOIN sys.index_columns ic ON ic.object_id = t.object_id
+ AND ic.index_id = i.index_id
INNER JOIN sys.columns c ON c.object_id = t.object_id
AND ic.column_id = c.column_id
WHERE i.is_unique=0 AND i.is_primary_key=0 AND i.is_unique_constraint=0
diff --git a/awx/lib/site-packages/south/db/sqlite3.py b/awx/lib/site-packages/south/db/sqlite3.py
index db45511456..ecdb2ce6d2 100644
--- a/awx/lib/site-packages/south/db/sqlite3.py
+++ b/awx/lib/site-packages/south/db/sqlite3.py
@@ -31,7 +31,7 @@ class DatabaseOperations(generic.DatabaseOperations):
field_default = None
if not getattr(field, '_suppress_default', False):
default = field.get_default()
- if default is not None and default!='':
+ if default is not None:
field_default = "'%s'" % field.get_db_prep_save(default, connection=self._get_connection())
field._suppress_default = True
self._remake_table(table_name, added={
@@ -136,7 +136,7 @@ class DatabaseOperations(generic.DatabaseOperations):
continue
src_fields_new.append(self.quote_name(field))
for field, (_,default) in added.items():
- if default is not None and default!='':
+ if default is not None:
field = self.quote_name(field)
src_fields_new.append("%s as %s" % (default, field))
dst_fields_new.append(field)
@@ -263,10 +263,3 @@ class DatabaseOperations(generic.DatabaseOperations):
# No cascades on deletes
def delete_table(self, table_name, cascade=True):
generic.DatabaseOperations.delete_table(self, table_name, False)
-
- def _default_value_workaround(self, default):
- if default == True:
- default = 1
- elif default == False:
- default = 0
- return default
diff --git a/awx/lib/site-packages/south/hacks/django_1_0.py b/awx/lib/site-packages/south/hacks/django_1_0.py
index 00d0a8bab3..e4a60c66a9 100644
--- a/awx/lib/site-packages/south/hacks/django_1_0.py
+++ b/awx/lib/site-packages/south/hacks/django_1_0.py
@@ -2,6 +2,7 @@
Hacks for the Django 1.0/1.0.2 releases.
"""
+import django
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation
from django.db.models.loading import cache
@@ -51,7 +52,7 @@ class Hacks:
Used to repopulate AppCache after fiddling with INSTALLED_APPS.
"""
cache.loaded = False
- cache.handled = {}
+ cache.handled = set() if django.VERSION >= (1, 6) else {}
cache.postponed = []
cache.app_store = SortedDict()
cache.app_models = SortedDict()
diff --git a/awx/lib/site-packages/south/management/commands/datamigration.py b/awx/lib/site-packages/south/management/commands/datamigration.py
index 08c1d0891c..9254e56c15 100644
--- a/awx/lib/site-packages/south/management/commands/datamigration.py
+++ b/awx/lib/site-packages/south/management/commands/datamigration.py
@@ -34,6 +34,8 @@ class Command(BaseCommand):
usage_str = "Usage: ./manage.py datamigration appname migrationname [--stdout] [--freeze appname]"
def handle(self, app=None, name="", freeze_list=None, stdout=False, verbosity=1, **options):
+
+ verbosity = int(verbosity)
# Any supposed lists that are None become empty lists
freeze_list = freeze_list or []
@@ -46,13 +48,19 @@ class Command(BaseCommand):
if re.search('[^_\w]', name) and name != "-":
self.error("Migration names should contain only alphanumeric characters and underscores.")
- # if not name, there's an error
+ # If not name, there's an error
if not name:
- self.error("You must provide a name for this migration\n" + self.usage_str)
+ self.error("You must provide a name for this migration.\n" + self.usage_str)
if not app:
self.error("You must provide an app to create a migration for.\n" + self.usage_str)
-
+
+ # Ensure that verbosity is not a string (Python 3)
+ try:
+ verbosity = int(verbosity)
+ except ValueError:
+ self.error("Verbosity must be an number.\n" + self.usage_str)
+
# Get the Migrations for this app (creating the migrations dir if needed)
migrations = Migrations(app, force_creation=True, verbose_creation=verbosity > 0)
@@ -63,7 +71,7 @@ class Command(BaseCommand):
apps_to_freeze = self.calc_frozen_apps(migrations, freeze_list)
# So, what's in this file, then?
- file_contents = MIGRATION_TEMPLATE % {
+ file_contents = self.get_migration_template() % {
"frozen_models": freezer.freeze_apps_to_string(apps_to_freeze),
"complete_apps": apps_to_freeze and "complete_apps = [%s]" % (", ".join(map(repr, apps_to_freeze))) or ""
}
@@ -103,6 +111,9 @@ class Command(BaseCommand):
print(message, file=sys.stderr)
sys.exit(code)
+ def get_migration_template(self):
+ return MIGRATION_TEMPLATE
+
MIGRATION_TEMPLATE = """# -*- coding: utf-8 -*-
import datetime
diff --git a/awx/lib/site-packages/south/management/commands/schemamigration.py b/awx/lib/site-packages/south/management/commands/schemamigration.py
index e29fc620b6..514f4a79d9 100644
--- a/awx/lib/site-packages/south/management/commands/schemamigration.py
+++ b/awx/lib/site-packages/south/management/commands/schemamigration.py
@@ -168,7 +168,7 @@ class Command(DataCommand):
apps_to_freeze = self.calc_frozen_apps(migrations, freeze_list)
# So, what's in this file, then?
- file_contents = MIGRATION_TEMPLATE % {
+ file_contents = self.get_migration_template() % {
"forwards": "\n".join(forwards_actions or [" pass"]),
"backwards": "\n".join(backwards_actions or [" pass"]),
"frozen_models": freezer.freeze_apps_to_string(apps_to_freeze),
@@ -205,6 +205,9 @@ class Command(DataCommand):
else:
print("%s %s. You can now apply this migration with: ./manage.py migrate %s" % (verb, new_filename, app), file=sys.stderr)
+ def get_migration_template(self):
+ return MIGRATION_TEMPLATE
+
MIGRATION_TEMPLATE = """# -*- coding: utf-8 -*-
import datetime
diff --git a/awx/lib/site-packages/south/management/commands/syncdb.py b/awx/lib/site-packages/south/management/commands/syncdb.py
index 702085b194..17fc22cbfc 100644
--- a/awx/lib/site-packages/south/management/commands/syncdb.py
+++ b/awx/lib/site-packages/south/management/commands/syncdb.py
@@ -98,6 +98,8 @@ class Command(NoArgsCommand):
if options.get('migrate', True):
if verbosity:
print("Migrating...")
+ # convert from store_true to store_false
+ options['no_initial_data'] = not options.get('load_initial_data', True)
management.call_command('migrate', **options)
# Be obvious about what we did
diff --git a/awx/lib/site-packages/south/migration/migrators.py b/awx/lib/site-packages/south/migration/migrators.py
index 1be895dcf3..98b3489165 100644
--- a/awx/lib/site-packages/south/migration/migrators.py
+++ b/awx/lib/site-packages/south/migration/migrators.py
@@ -5,10 +5,6 @@ import datetime
import inspect
import sys
import traceback
-try:
- from cStringIO import StringIO # python 2
-except ImportError:
- from io import StringIO # python 3
from django.core.management import call_command
from django.core.management.commands import loaddata
@@ -19,6 +15,7 @@ from south import exceptions
from south.db import DEFAULT_DB_ALIAS
from south.models import MigrationHistory
from south.signals import ran_migration
+from south.utils.py3 import StringIO
class Migrator(object):
diff --git a/awx/lib/site-packages/south/orm.py b/awx/lib/site-packages/south/orm.py
index 1e5d56ed79..8d46ee7194 100644
--- a/awx/lib/site-packages/south/orm.py
+++ b/awx/lib/site-packages/south/orm.py
@@ -181,12 +181,16 @@ class _FakeORM(object):
"Evaluates the given code in the context of the migration file."
# Drag in the migration module's locals (hopefully including models.py)
- fake_locals = dict(inspect.getmodule(self.cls).__dict__)
-
- # Remove all models from that (i.e. from modern models.py), to stop pollution
- for key, value in fake_locals.items():
- if isinstance(value, type) and issubclass(value, models.Model) and hasattr(value, "_meta"):
- del fake_locals[key]
+ # excluding all models from that (i.e. from modern models.py), to stop pollution
+ fake_locals = dict(
+ (key, value)
+ for key, value in inspect.getmodule(self.cls).__dict__.items()
+ if not (
+ isinstance(value, type)
+ and issubclass(value, models.Model)
+ and hasattr(value, "_meta")
+ )
+ )
# We add our models into the locals for the eval
fake_locals.update(dict([
diff --git a/awx/lib/site-packages/south/tests/db.py b/awx/lib/site-packages/south/tests/db.py
index 90f62fc0be..353677ebed 100644
--- a/awx/lib/site-packages/south/tests/db.py
+++ b/awx/lib/site-packages/south/tests/db.py
@@ -360,7 +360,57 @@ class TestOperations(unittest.TestCase):
db.execute("INSERT INTO test_altercd (eggs) values (12)")
null = db.execute("SELECT spam FROM test_altercd")[0][0]
self.assertFalse(null, "Default for char field was installed into database")
+
+ # Change again to a column with default and not null
+ db.alter_column("test_altercd", "spam", models.CharField(max_length=30, default="loof", null=False))
+ # Assert the default is not in the database
+ if 'oracle' in db.backend_name:
+ # Oracle special treatment -- nulls are always allowed in char columns, so
+ # inserting doesn't raise an integrity error; so we check again as above
+ db.execute("DELETE FROM test_altercd")
+ db.execute("INSERT INTO test_altercd (eggs) values (12)")
+ null = db.execute("SELECT spam FROM test_altercd")[0][0]
+ self.assertFalse(null, "Default for char field was installed into database")
+ else:
+ # For other backends, insert should now just fail
+ self.assertRaises(IntegrityError,
+ db.execute, "INSERT INTO test_altercd (eggs) values (12)")
+
+ @skipIf('oracle' in db.backend_name, "Oracle does not differentiate empty trings from null")
+ def test_default_empty_string(self):
+ """
+ Test altering column defaults with char fields
+ """
+ db.create_table("test_cd_empty", [
+ ('spam', models.CharField(max_length=30, default='')),
+ ('eggs', models.CharField(max_length=30)),
+ ])
+ # Create a record
+ db.execute("INSERT INTO test_cd_empty (spam, eggs) values ('1','2')")
+ # Add a column
+ db.add_column("test_cd_empty", "ham", models.CharField(max_length=30, default=''))
+ empty = db.execute("SELECT ham FROM test_cd_empty")[0][0]
+ self.assertEquals(empty, "", "Empty Default for char field isn't empty string")
+
+ @skipUnless('oracle' in db.backend_name, "Oracle does not differentiate empty trings from null")
+ def test_oracle_strings_null(self):
+ """
+ Test that under Oracle, CherFields are created as null even when specified not-null,
+ because otherwise they would not be able to hold empty strings (which Oracle equates
+ with nulls).
+ Verify fix of #1269.
+ """
+ db.create_table("test_ora_char_nulls", [
+ ('spam', models.CharField(max_length=30, null=True)),
+ ('eggs', models.CharField(max_length=30)),
+ ])
+ db.add_column("test_ora_char_nulls", "ham", models.CharField(max_length=30))
+ db.alter_column("test_ora_char_nulls", "spam", models.CharField(max_length=30, null=False))
+ # So, by the look of it, we should now have three not-null columns
+ db.execute("INSERT INTO test_ora_char_nulls VALUES (NULL, NULL, NULL)")
+
+
def test_mysql_defaults(self):
"""
Test MySQL default handling for BLOB and TEXT.
@@ -568,8 +618,8 @@ class TestOperations(unittest.TestCase):
db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (1, 42)")
except:
self.fail("Looks like multi-field unique constraint applied to only one field.")
- db.start_transaction()
db.rollback_transaction()
+ db.start_transaction()
try:
db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (0, 43)")
except:
@@ -581,17 +631,17 @@ class TestOperations(unittest.TestCase):
except:
pass
else:
- self.fail("Could insert the same integer twice into a unique field.")
+ self.fail("Could insert the same pair twice into unique-together fields.")
db.rollback_transaction()
# Altering one column should not drop or modify multi-column constraint
- db.alter_column("test_alter_unique2", "eggs", models.CharField(max_length=10))
+ db.alter_column("test_alter_unique2", "eggs", models.PositiveIntegerField())
db.start_transaction()
try:
db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (1, 42)")
except:
self.fail("Altering one column broken multi-column unique constraint.")
- db.start_transaction()
db.rollback_transaction()
+ db.start_transaction()
try:
db.execute("INSERT INTO test_alter_unique2 (spam, eggs) VALUES (0, 43)")
except:
@@ -603,7 +653,7 @@ class TestOperations(unittest.TestCase):
except:
pass
else:
- self.fail("Could insert the same integer twice into a unique field after alter_column with unique=False.")
+ self.fail("Could insert the same pair twice into unique-together fields after alter_column with unique=False.")
db.rollback_transaction()
db.delete_table("test_alter_unique2")
db.start_transaction()
@@ -811,10 +861,39 @@ class TestOperations(unittest.TestCase):
db.add_column("test_fk", 'foreik', models.ForeignKey(User, null=True))
db.execute_deferred_sql()
- # Make the FK null
+ # Make the FK not null
db.alter_column("test_fk", "foreik_id", models.ForeignKey(User))
db.execute_deferred_sql()
+ def test_make_foreign_key_null(self):
+ # Table for FK to target
+ User = db.mock_model(model_name='User', db_table='auth_user', db_tablespace='', pk_field_name='id', pk_field_type=models.AutoField, pk_field_args=[], pk_field_kwargs={})
+ # Table with no foreign key
+ db.create_table("test_make_fk_null", [
+ ('eggs', models.IntegerField()),
+ ('foreik', models.ForeignKey(User))
+ ])
+ db.execute_deferred_sql()
+
+ # Make the FK null
+ db.alter_column("test_make_fk_null", "foreik_id", models.ForeignKey(User, null=True))
+ db.execute_deferred_sql()
+
+ def test_alter_double_indexed_column(self):
+ # Table for FK to target
+ User = db.mock_model(model_name='User', db_table='auth_user', db_tablespace='', pk_field_name='id', pk_field_type=models.AutoField, pk_field_args=[], pk_field_kwargs={})
+ # Table with no foreign key
+ db.create_table("test_2indexed", [
+ ('eggs', models.IntegerField()),
+ ('foreik', models.ForeignKey(User))
+ ])
+ db.create_unique("test_2indexed", ["eggs", "foreik_id"])
+ db.execute_deferred_sql()
+
+ # Make the FK null
+ db.alter_column("test_2indexed", "foreik_id", models.ForeignKey(User, null=True))
+ db.execute_deferred_sql()
+
class TestCacheGeneric(unittest.TestCase):
base_ops_cls = generic.DatabaseOperations
def setUp(self):
diff --git a/awx/lib/site-packages/south/utils/__init__.py b/awx/lib/site-packages/south/utils/__init__.py
index c3c5191633..8d7297ea5d 100644
--- a/awx/lib/site-packages/south/utils/__init__.py
+++ b/awx/lib/site-packages/south/utils/__init__.py
@@ -5,13 +5,13 @@ Generally helpful utility functions.
def _ask_for_it_by_name(name):
"Returns an object referenced by absolute path."
- bits = name.split(".")
+ bits = str(name).split(".")
## what if there is no absolute reference?
- if len(bits)>1:
+ if len(bits) > 1:
modulename = ".".join(bits[:-1])
else:
- modulename=bits[0]
+ modulename = bits[0]
module = __import__(modulename, {}, {}, bits[-1])
diff --git a/awx/lib/site-packages/south/utils/py3.py b/awx/lib/site-packages/south/utils/py3.py
index 9c5baaded0..732e9043a8 100644
--- a/awx/lib/site-packages/south/utils/py3.py
+++ b/awx/lib/site-packages/south/utils/py3.py
@@ -11,11 +11,18 @@ if PY3:
text_type = str
raw_input = input
+ import io
+ StringIO = io.StringIO
+
else:
string_types = basestring,
text_type = unicode
raw_input = raw_input
+ import cStringIO
+ StringIO = cStringIO.StringIO
+
+
def with_metaclass(meta, base=object):
"""Create a base class with a metaclass."""
return meta("NewBase", (base,), {})
diff --git a/awx/lib/site-packages/taggit/__init__.py b/awx/lib/site-packages/taggit/__init__.py
index 4c4993466d..ca04bbe69d 100644
--- a/awx/lib/site-packages/taggit/__init__.py
+++ b/awx/lib/site-packages/taggit/__init__.py
@@ -1 +1 @@
-VERSION = (0, 10, 0, 'alpha', 1)
+VERSION = (0, 10, 0)
diff --git a/awx/lib/site-packages/taggit/admin.py b/awx/lib/site-packages/taggit/admin.py
index 6c012d6fe3..0498c9ddf4 100644
--- a/awx/lib/site-packages/taggit/admin.py
+++ b/awx/lib/site-packages/taggit/admin.py
@@ -1,3 +1,5 @@
+from __future__ import unicode_literals
+
from django.contrib import admin
from taggit.models import Tag, TaggedItem
diff --git a/awx/lib/site-packages/taggit/forms.py b/awx/lib/site-packages/taggit/forms.py
index e0198bd933..cb372f22a0 100644
--- a/awx/lib/site-packages/taggit/forms.py
+++ b/awx/lib/site-packages/taggit/forms.py
@@ -1,12 +1,15 @@
+from __future__ import unicode_literals
+
from django import forms
from django.utils.translation import ugettext as _
+from django.utils import six
from taggit.utils import parse_tags, edit_string_for_tags
class TagWidget(forms.TextInput):
def render(self, name, value, attrs=None):
- if value is not None and not isinstance(value, basestring):
+ if value is not None and not isinstance(value, six.string_types):
value = edit_string_for_tags([o.tag for o in value.select_related("tag")])
return super(TagWidget, self).render(name, value, attrs)
diff --git a/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mo b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mo
new file mode 100644
index 0000000000..9ce13fb2c2
--- /dev/null
+++ b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.mo
Binary files differ
diff --git a/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po
new file mode 100644
index 0000000000..13262e1c78
--- /dev/null
+++ b/awx/lib/site-packages/taggit/locale/cs/LC_MESSAGES/django.po
@@ -0,0 +1,64 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2013-08-01 16:52+0200\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language-Team: LANGUAGE <LL@li.org>\n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n==1) ? 0 : (n>=2 && n<=4) ? 1 : 2;\n"
+
+#: forms.py:24
+msgid "Please provide a comma-separated list of tags."
+msgstr "Vložte čárkami oddělený seznam tagů"
+
+#: managers.py:59 models.py:59
+msgid "Tags"
+msgstr "Tagy"
+
+#: managers.py:60
+msgid "A comma-separated list of tags."
+msgstr "Čárkami oddělený seznam tagů"
+
+#: models.py:15
+msgid "Name"
+msgstr "Jméno"
+
+#: models.py:16
+msgid "Slug"
+msgstr "Slug"
+
+#: models.py:58
+msgid "Tag"
+msgstr "Tag"
+
+#: models.py:65
+#, python-format
+msgid "%(object)s tagged with %(tag)s"
+msgstr "%(object)s označen tagem %(tag)s"
+
+#: models.py:112
+msgid "Object id"
+msgstr "ID objektu"
+
+#: models.py:115
+msgid "Content type"
+msgstr "Typ obsahu"
+
+#: models.py:158
+msgid "Tagged Item"
+msgstr "Tagem označená položka"
+
+#: models.py:159
+msgid "Tagged Items"
+msgstr "Tagy označené položky"
diff --git a/awx/lib/site-packages/taggit/managers.py b/awx/lib/site-packages/taggit/managers.py
index ca1700fd96..56fec6064c 100644
--- a/awx/lib/site-packages/taggit/managers.py
+++ b/awx/lib/site-packages/taggit/managers.py
@@ -1,57 +1,78 @@
+from __future__ import unicode_literals
+
+from django import VERSION
from django.contrib.contenttypes.generic import GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
+from django.db.models.fields import Field
from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation
from django.db.models.related import RelatedObject
from django.utils.text import capfirst
from django.utils.translation import ugettext_lazy as _
+from django.utils import six
+
+try:
+ from django.db.models.related import PathInfo
+except ImportError:
+ pass # PathInfo is not used on Django < 1.6
from taggit.forms import TagField
from taggit.models import TaggedItem, GenericTaggedItemBase
from taggit.utils import require_instance_manager
-try:
- all
-except NameError:
- # 2.4 compat
- try:
- from django.utils.itercompat import all
- except ImportError:
- # 1.1.X compat
- def all(iterable):
- for item in iterable:
- if not item:
- return False
- return True
+def _model_name(model):
+ if VERSION < (1, 7):
+ return model._meta.module_name
+ else:
+ return model._meta.model_name
class TaggableRel(ManyToManyRel):
- def __init__(self):
+ def __init__(self, field):
self.related_name = None
self.limit_choices_to = {}
self.symmetrical = True
self.multiple = True
self.through = None
+ self.field = field
+
+ def get_joining_columns(self):
+ return self.field.get_reverse_joining_columns()
+
+ def get_extra_restriction(self, where_class, alias, related_alias):
+ return self.field.get_extra_restriction(where_class, related_alias, alias)
+
+class ExtraJoinRestriction(object):
+ """
+ An extra restriction used for contenttype restriction in joins.
+ """
+ def __init__(self, alias, col, content_types):
+ self.alias = alias
+ self.col = col
+ self.content_types = content_types
-class TaggableManager(RelatedField):
+ def as_sql(self, qn, connection):
+ if len(self.content_types) == 1:
+ extra_where = "%s.%s = %%s" % (qn(self.alias), qn(self.col))
+ params = [self.content_types[0]]
+ else:
+ extra_where = "%s.%s IN (%s)" % (qn(self.alias), qn(self.col),
+ ','.join(['%s'] * len(self.content_types)))
+ params = self.content_types
+ return extra_where, params
+
+ def relabel_aliases(self, change_map):
+ self.alias = change_map.get(self.alias, self.alias)
+
+
+class TaggableManager(RelatedField, Field):
def __init__(self, verbose_name=_("Tags"),
help_text=_("A comma-separated list of tags."), through=None, blank=False):
+ Field.__init__(self, verbose_name=verbose_name, help_text=help_text, blank=blank)
self.through = through or TaggedItem
- self.rel = TaggableRel()
- self.verbose_name = verbose_name
- self.help_text = help_text
- self.blank = blank
- self.editable = True
- self.unique = False
- self.creates_table = False
- self.db_column = None
- self.choices = None
- self.serialize = False
- self.null = True
- self.creation_counter = models.Field.creation_counter
- models.Field.creation_counter += 1
+ self.rel = TaggableRel(self)
def __get__(self, instance, model):
if instance is not None and instance.pk is None:
@@ -63,12 +84,15 @@ class TaggableManager(RelatedField):
return manager
def contribute_to_class(self, cls, name):
- self.name = self.column = name
+ if VERSION < (1, 7):
+ self.name = self.column = name
+ else:
+ self.set_attributes_from_name(name)
self.model = cls
cls._meta.add_field(self)
setattr(cls, name, self)
if not cls._meta.abstract:
- if isinstance(self.through, basestring):
+ if isinstance(self.through, six.string_types):
def resolve_related_class(field, model, cls):
self.through = model
self.post_through_setup(cls)
@@ -78,7 +102,16 @@ class TaggableManager(RelatedField):
else:
self.post_through_setup(cls)
+ def __lt__(self, other):
+ """
+ Required contribute_to_class as Django uses bisect
+ for ordered class contribution and bisect requires
+ a orderable type in py3.
+ """
+ return False
+
def post_through_setup(self, cls):
+ self.related = RelatedObject(cls, self.model, self)
self.use_gfk = (
self.through is None or issubclass(self.through, GenericTaggedItemBase)
)
@@ -106,11 +139,14 @@ class TaggableManager(RelatedField):
return self.through.objects.none()
def related_query_name(self):
- return self.model._meta.module_name
+ return _model_name(self.model)
def m2m_reverse_name(self):
return self.through._meta.get_field_by_name("tag")[0].column
+ def m2m_reverse_field_name(self):
+ return self.through._meta.get_field_by_name("tag")[0].name
+
def m2m_target_field_name(self):
return self.model._meta.pk.name
@@ -128,17 +164,148 @@ class TaggableManager(RelatedField):
def m2m_db_table(self):
return self.through._meta.db_table
+ def bulk_related_objects(self, new_objs, using):
+ return []
+
def extra_filters(self, pieces, pos, negate):
if negate or not self.use_gfk:
return []
prefix = "__".join(["tagged_items"] + pieces[:pos-2])
- cts = map(ContentType.objects.get_for_model, _get_subclasses(self.model))
+ get = ContentType.objects.get_for_model
+ cts = [get(obj) for obj in _get_subclasses(self.model)]
if len(cts) == 1:
return [("%s__content_type" % prefix, cts[0])]
return [("%s__content_type__in" % prefix, cts)]
- def bulk_related_objects(self, new_objs, using):
- return []
+ def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
+ model_name = _model_name(self.through)
+ if rhs_alias == '%s_%s' % (self.through._meta.app_label, model_name):
+ alias_to_join = rhs_alias
+ else:
+ alias_to_join = lhs_alias
+ extra_col = self.through._meta.get_field_by_name('content_type')[0].column
+ content_type_ids = [ContentType.objects.get_for_model(subclass).pk for subclass in _get_subclasses(self.model)]
+ if len(content_type_ids) == 1:
+ content_type_id = content_type_ids[0]
+ extra_where = " AND %s.%s = %%s" % (qn(alias_to_join), qn(extra_col))
+ params = [content_type_id]
+ else:
+ extra_where = " AND %s.%s IN (%s)" % (qn(alias_to_join), qn(extra_col), ','.join(['%s']*len(content_type_ids)))
+ params = content_type_ids
+ return extra_where, params
+
+ def _get_mm_case_path_info(self, direct=False):
+ pathinfos = []
+ linkfield1 = self.through._meta.get_field_by_name('content_object')[0]
+ linkfield2 = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
+ if direct:
+ join1infos, _, _, _ = linkfield1.get_reverse_path_info()
+ join2infos, opts, target, final = linkfield2.get_path_info()
+ else:
+ join1infos, _, _, _ = linkfield2.get_reverse_path_info()
+ join2infos, opts, target, final = linkfield1.get_path_info()
+ pathinfos.extend(join1infos)
+ pathinfos.extend(join2infos)
+ return pathinfos, opts, target, final
+
+ def _get_gfk_case_path_info(self, direct=False):
+ pathinfos = []
+ from_field = self.model._meta.pk
+ opts = self.through._meta
+ object_id_field = opts.get_field_by_name('object_id')[0]
+ linkfield = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
+ if direct:
+ join1infos = [PathInfo(from_field, object_id_field, self.model._meta, opts, self, True, False)]
+ join2infos, opts, target, final = linkfield.get_path_info()
+ else:
+ join1infos, _, _, _ = linkfield.get_reverse_path_info()
+ join2infos = [PathInfo(object_id_field, from_field, opts, self.model._meta, self, True, False)]
+ target = from_field
+ final = self
+ opts = self.model._meta
+
+ pathinfos.extend(join1infos)
+ pathinfos.extend(join2infos)
+ return pathinfos, opts, target, final
+
+ def get_path_info(self):
+ if self.use_gfk:
+ return self._get_gfk_case_path_info(direct=True)
+ else:
+ return self._get_mm_case_path_info(direct=True)
+
+ def get_reverse_path_info(self):
+ if self.use_gfk:
+ return self._get_gfk_case_path_info(direct=False)
+ else:
+ return self._get_mm_case_path_info(direct=False)
+
+ # This and all the methods till the end of class are only used in django >= 1.6
+ def _get_mm_case_path_info(self, direct=False):
+ pathinfos = []
+ linkfield1 = self.through._meta.get_field_by_name('content_object')[0]
+ linkfield2 = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
+ if direct:
+ join1infos = linkfield1.get_reverse_path_info()
+ join2infos = linkfield2.get_path_info()
+ else:
+ join1infos = linkfield2.get_reverse_path_info()
+ join2infos = linkfield1.get_path_info()
+ pathinfos.extend(join1infos)
+ pathinfos.extend(join2infos)
+ return pathinfos
+
+ def _get_gfk_case_path_info(self, direct=False):
+ pathinfos = []
+ from_field = self.model._meta.pk
+ opts = self.through._meta
+ object_id_field = opts.get_field_by_name('object_id')[0]
+ linkfield = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
+ if direct:
+ join1infos = [PathInfo(self.model._meta, opts, [from_field], self.rel, True, False)]
+ join2infos = linkfield.get_path_info()
+ else:
+ join1infos = linkfield.get_reverse_path_info()
+ join2infos = [PathInfo(opts, self.model._meta, [object_id_field], self, True, False)]
+ pathinfos.extend(join1infos)
+ pathinfos.extend(join2infos)
+ return pathinfos
+
+ def get_path_info(self):
+ if self.use_gfk:
+ return self._get_gfk_case_path_info(direct=True)
+ else:
+ return self._get_mm_case_path_info(direct=True)
+
+ def get_reverse_path_info(self):
+ if self.use_gfk:
+ return self._get_gfk_case_path_info(direct=False)
+ else:
+ return self._get_mm_case_path_info(direct=False)
+
+ def get_joining_columns(self, reverse_join=False):
+ if reverse_join:
+ return (("id", "object_id"),)
+ else:
+ return (("object_id", "id"),)
+
+ def get_extra_restriction(self, where_class, alias, related_alias):
+ extra_col = self.through._meta.get_field_by_name('content_type')[0].column
+ content_type_ids = [ContentType.objects.get_for_model(subclass).pk
+ for subclass in _get_subclasses(self.model)]
+ return ExtraJoinRestriction(related_alias, extra_col, content_type_ids)
+
+ def get_reverse_joining_columns(self):
+ return self.get_joining_columns(reverse_join=True)
+
+ @property
+ def related_fields(self):
+ return [(self.through._meta.get_field_by_name('object_id')[0],
+ self.model._meta.pk)]
+
+ @property
+ def foreign_related_fields(self):
+ return [self.related_fields[0][1]]
class _TaggableManager(models.Manager):
@@ -150,6 +317,9 @@ class _TaggableManager(models.Manager):
def get_query_set(self):
return self.through.tags_for(self.model, self.instance)
+ # Django 1.6 renamed this
+ get_queryset = get_query_set
+
def _lookup_kwargs(self):
return self.through.lookup_kwargs(self.instance)
@@ -175,6 +345,14 @@ class _TaggableManager(models.Manager):
self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())
@require_instance_manager
+ def names(self):
+ return self.get_query_set().values_list('name', flat=True)
+
+ @require_instance_manager
+ def slugs(self):
+ return self.get_query_set().values_list('slug', flat=True)
+
+ @require_instance_manager
def set(self, *tags):
self.clear()
self.add(*tags)
@@ -197,7 +375,7 @@ class _TaggableManager(models.Manager):
def similar_objects(self):
lookup_kwargs = self._lookup_kwargs()
lookup_keys = sorted(lookup_kwargs)
- qs = self.through.objects.values(*lookup_kwargs.keys())
+ qs = self.through.objects.values(*six.iterkeys(lookup_kwargs))
qs = qs.annotate(n=models.Count('pk'))
qs = qs.exclude(**lookup_kwargs)
qs = qs.filter(tag__in=self.all())
@@ -220,7 +398,7 @@ class _TaggableManager(models.Manager):
preload.setdefault(result['content_type'], set())
preload[result["content_type"]].add(result["object_id"])
- for ct, obj_ids in preload.iteritems():
+ for ct, obj_ids in preload.items():
ct = ContentType.objects.get_for_id(ct)
for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids):
items[(ct.pk, obj.pk)] = obj
@@ -243,3 +421,11 @@ def _get_subclasses(model):
getattr(field.field.rel, "parent_link", None)):
subclasses.extend(_get_subclasses(field.model))
return subclasses
+
+
+# `total_ordering` does not exist in Django 1.4, as such
+# we special case this import to be py3k specific which
+# is not supported by Django 1.4
+if six.PY3:
+ from django.utils.functional import total_ordering
+ TaggableManager = total_ordering(TaggableManager)
diff --git a/awx/lib/site-packages/taggit/migrations/0001_initial.py b/awx/lib/site-packages/taggit/migrations/0001_initial.py
index 666ca6199e..6808f38bca 100644
--- a/awx/lib/site-packages/taggit/migrations/0001_initial.py
+++ b/awx/lib/site-packages/taggit/migrations/0001_initial.py
@@ -9,52 +9,52 @@ class Migration(SchemaMigration):
def forwards(self, orm):
# Adding model 'Tag'
- db.create_table(u'taggit_tag', (
- (u'id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
+ db.create_table('taggit_tag', (
+ ('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
('name', self.gf('django.db.models.fields.CharField')(max_length=100)),
('slug', self.gf('django.db.models.fields.SlugField')(unique=True, max_length=100)),
))
- db.send_create_signal(u'taggit', ['Tag'])
+ db.send_create_signal('taggit', ['Tag'])
# Adding model 'TaggedItem'
- db.create_table(u'taggit_taggeditem', (
- (u'id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
- ('tag', self.gf('django.db.models.fields.related.ForeignKey')(related_name=u'taggit_taggeditem_items', to=orm['taggit.Tag'])),
+ db.create_table('taggit_taggeditem', (
+ ('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
+ ('tag', self.gf('django.db.models.fields.related.ForeignKey')(related_name='taggit_taggeditem_items', to=orm['taggit.Tag'])),
('object_id', self.gf('django.db.models.fields.IntegerField')(db_index=True)),
- ('content_type', self.gf('django.db.models.fields.related.ForeignKey')(related_name=u'taggit_taggeditem_tagged_items', to=orm['contenttypes.ContentType'])),
+ ('content_type', self.gf('django.db.models.fields.related.ForeignKey')(related_name='taggit_taggeditem_tagged_items', to=orm['contenttypes.ContentType'])),
))
- db.send_create_signal(u'taggit', ['TaggedItem'])
+ db.send_create_signal('taggit', ['TaggedItem'])
def backwards(self, orm):
# Deleting model 'Tag'
- db.delete_table(u'taggit_tag')
+ db.delete_table('taggit_tag')
# Deleting model 'TaggedItem'
- db.delete_table(u'taggit_taggeditem')
+ db.delete_table('taggit_taggeditem')
models = {
- u'contenttypes.contenttype': {
+ 'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
- u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '100'})
},
- u'taggit.tag': {
+ 'taggit.tag': {
'Meta': {'object_name': 'Tag'},
- u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'slug': ('django.db.models.fields.SlugField', [], {'unique': 'True', 'max_length': '100'})
},
- u'taggit.taggeditem': {
+ 'taggit.taggeditem': {
'Meta': {'object_name': 'TaggedItem'},
- 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_tagged_items'", 'to': u"orm['contenttypes.ContentType']"}),
- u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_tagged_items'", 'to': "orm['contenttypes.ContentType']"}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'object_id': ('django.db.models.fields.IntegerField', [], {'db_index': 'True'}),
- 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_items'", 'to': u"orm['taggit.Tag']"})
+ 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_items'", 'to': "orm['taggit.Tag']"})
}
}
- complete_apps = ['taggit'] \ No newline at end of file
+ complete_apps = ['taggit']
diff --git a/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py b/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py
index e5eb033b0a..d68ea10164 100644
--- a/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py
+++ b/awx/lib/site-packages/taggit/migrations/0002_unique_tagnames.py
@@ -9,35 +9,35 @@ class Migration(SchemaMigration):
def forwards(self, orm):
# Adding unique constraint on 'Tag', fields ['name']
- db.create_unique(u'taggit_tag', ['name'])
+ db.create_unique('taggit_tag', ['name'])
def backwards(self, orm):
# Removing unique constraint on 'Tag', fields ['name']
- db.delete_unique(u'taggit_tag', ['name'])
+ db.delete_unique('taggit_tag', ['name'])
models = {
- u'contenttypes.contenttype': {
+ 'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
- u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '100'})
},
- u'taggit.tag': {
+ 'taggit.tag': {
'Meta': {'object_name': 'Tag'},
- u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '100'}),
'slug': ('django.db.models.fields.SlugField', [], {'unique': 'True', 'max_length': '100'})
},
- u'taggit.taggeditem': {
+ 'taggit.taggeditem': {
'Meta': {'object_name': 'TaggedItem'},
- 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_tagged_items'", 'to': u"orm['contenttypes.ContentType']"}),
- u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_tagged_items'", 'to': "orm['contenttypes.ContentType']"}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'object_id': ('django.db.models.fields.IntegerField', [], {'db_index': 'True'}),
- 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "u'taggit_taggeditem_items'", 'to': u"orm['taggit.Tag']"})
+ 'tag': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'taggit_taggeditem_items'", 'to': "orm['taggit.Tag']"})
}
}
- complete_apps = ['taggit'] \ No newline at end of file
+ complete_apps = ['taggit']
diff --git a/awx/lib/site-packages/taggit/models.py b/awx/lib/site-packages/taggit/models.py
index 581a5b1920..0335757875 100644
--- a/awx/lib/site-packages/taggit/models.py
+++ b/awx/lib/site-packages/taggit/models.py
@@ -1,16 +1,21 @@
+from __future__ import unicode_literals
+
import django
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericForeignKey
from django.db import models, IntegrityError, transaction
+from django.db.models.query import QuerySet
from django.template.defaultfilters import slugify as default_slugify
from django.utils.translation import ugettext_lazy as _, ugettext
+from django.utils.encoding import python_2_unicode_compatible
+@python_2_unicode_compatible
class TagBase(models.Model):
name = models.CharField(verbose_name=_('Name'), unique=True, max_length=100)
slug = models.SlugField(verbose_name=_('Slug'), unique=True, max_length=100)
- def __unicode__(self):
+ def __str__(self):
return self.name
class Meta:
@@ -19,17 +24,14 @@ class TagBase(models.Model):
def save(self, *args, **kwargs):
if not self.pk and not self.slug:
self.slug = self.slugify(self.name)
- if django.VERSION >= (1, 2):
- from django.db import router
- using = kwargs.get("using") or router.db_for_write(
- type(self), instance=self)
- # Make sure we write to the same db for all attempted writes,
- # with a multi-master setup, theoretically we could try to
- # write and rollback on different DBs
- kwargs["using"] = using
- trans_kwargs = {"using": using}
- else:
- trans_kwargs = {}
+ from django.db import router
+ using = kwargs.get("using") or router.db_for_write(
+ type(self), instance=self)
+ # Make sure we write to the same db for all attempted writes,
+ # with a multi-master setup, theoretically we could try to
+ # write and rollback on different DBs
+ kwargs["using"] = using
+ trans_kwargs = {"using": using}
i = 0
while True:
i += 1
@@ -57,9 +59,9 @@ class Tag(TagBase):
verbose_name_plural = _("Tags")
-
+@python_2_unicode_compatible
class ItemBase(models.Model):
- def __unicode__(self):
+ def __str__(self):
return ugettext("%(object)s tagged with %(tag)s") % {
"object": self.content_object,
"tag": self.tag
@@ -90,10 +92,7 @@ class ItemBase(models.Model):
class TaggedItemBase(ItemBase):
- if django.VERSION < (1, 2):
- tag = models.ForeignKey(Tag, related_name="%(class)s_items")
- else:
- tag = models.ForeignKey(Tag, related_name="%(app_label)s_%(class)s_items")
+ tag = models.ForeignKey(Tag, related_name="%(app_label)s_%(class)s_items")
class Meta:
abstract = True
@@ -111,18 +110,11 @@ class TaggedItemBase(ItemBase):
class GenericTaggedItemBase(ItemBase):
object_id = models.IntegerField(verbose_name=_('Object id'), db_index=True)
- if django.VERSION < (1, 2):
- content_type = models.ForeignKey(
- ContentType,
- verbose_name=_('Content type'),
- related_name="%(class)s_tagged_items"
- )
- else:
- content_type = models.ForeignKey(
- ContentType,
- verbose_name=_('Content type'),
- related_name="%(app_label)s_%(class)s_tagged_items"
- )
+ content_type = models.ForeignKey(
+ ContentType,
+ verbose_name=_('Content type'),
+ related_name="%(app_label)s_%(class)s_tagged_items"
+ )
content_object = GenericForeignKey()
class Meta:
@@ -137,11 +129,18 @@ class GenericTaggedItemBase(ItemBase):
@classmethod
def bulk_lookup_kwargs(cls, instances):
- # TODO: instances[0], can we assume there are instances.
- return {
- "object_id__in": [instance.pk for instance in instances],
- "content_type": ContentType.objects.get_for_model(instances[0]),
- }
+ if isinstance(instances, QuerySet):
+ # Can do a real object_id IN (SELECT ..) query.
+ return {
+ "object_id__in": instances,
+ "content_type": ContentType.objects.get_for_model(instances.model),
+ }
+ else:
+ # TODO: instances[0], can we assume there are instances.
+ return {
+ "object_id__in": [instance.pk for instance in instances],
+ "content_type": ContentType.objects.get_for_model(instances[0]),
+ }
@classmethod
def tags_for(cls, model, instance=None):
diff --git a/awx/lib/site-packages/taggit/tests/__init__.py b/awx/lib/site-packages/taggit/tests/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/awx/lib/site-packages/taggit/tests/__init__.py
+++ /dev/null
diff --git a/awx/lib/site-packages/taggit/tests/forms.py b/awx/lib/site-packages/taggit/tests/forms.py
deleted file mode 100644
index 2cdc6a8dca..0000000000
--- a/awx/lib/site-packages/taggit/tests/forms.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from django import forms
-
-from taggit.tests.models import Food, DirectFood, CustomPKFood, OfficialFood
-
-
-class FoodForm(forms.ModelForm):
- class Meta:
- model = Food
-
-class DirectFoodForm(forms.ModelForm):
- class Meta:
- model = DirectFood
-
-class CustomPKFoodForm(forms.ModelForm):
- class Meta:
- model = CustomPKFood
-
-class OfficialFoodForm(forms.ModelForm):
- class Meta:
- model = OfficialFood
diff --git a/awx/lib/site-packages/taggit/tests/models.py b/awx/lib/site-packages/taggit/tests/models.py
deleted file mode 100644
index a0e21e046f..0000000000
--- a/awx/lib/site-packages/taggit/tests/models.py
+++ /dev/null
@@ -1,143 +0,0 @@
-from django.db import models
-
-from taggit.managers import TaggableManager
-from taggit.models import (TaggedItemBase, GenericTaggedItemBase, TaggedItem,
- TagBase, Tag)
-
-
-class Food(models.Model):
- name = models.CharField(max_length=50)
-
- tags = TaggableManager()
-
- def __unicode__(self):
- return self.name
-
-class Pet(models.Model):
- name = models.CharField(max_length=50)
-
- tags = TaggableManager()
-
- def __unicode__(self):
- return self.name
-
-class HousePet(Pet):
- trained = models.BooleanField()
-
-
-# Test direct-tagging with custom through model
-
-class TaggedFood(TaggedItemBase):
- content_object = models.ForeignKey('DirectFood')
-
-class TaggedPet(TaggedItemBase):
- content_object = models.ForeignKey('DirectPet')
-
-class DirectFood(models.Model):
- name = models.CharField(max_length=50)
-
- tags = TaggableManager(through="TaggedFood")
-
-class DirectPet(models.Model):
- name = models.CharField(max_length=50)
-
- tags = TaggableManager(through=TaggedPet)
-
- def __unicode__(self):
- return self.name
-
-class DirectHousePet(DirectPet):
- trained = models.BooleanField()
-
-
-# Test custom through model to model with custom PK
-
-class TaggedCustomPKFood(TaggedItemBase):
- content_object = models.ForeignKey('CustomPKFood')
-
-class TaggedCustomPKPet(TaggedItemBase):
- content_object = models.ForeignKey('CustomPKPet')
-
-class CustomPKFood(models.Model):
- name = models.CharField(max_length=50, primary_key=True)
-
- tags = TaggableManager(through=TaggedCustomPKFood)
-
- def __unicode__(self):
- return self.name
-
-class CustomPKPet(models.Model):
- name = models.CharField(max_length=50, primary_key=True)
-
- tags = TaggableManager(through=TaggedCustomPKPet)
-
- def __unicode__(self):
- return self.name
-
-class CustomPKHousePet(CustomPKPet):
- trained = models.BooleanField()
-
-# Test custom through model to a custom tag model
-
-class OfficialTag(TagBase):
- official = models.BooleanField()
-
-class OfficialThroughModel(GenericTaggedItemBase):
- tag = models.ForeignKey(OfficialTag, related_name="tagged_items")
-
-class OfficialFood(models.Model):
- name = models.CharField(max_length=50)
-
- tags = TaggableManager(through=OfficialThroughModel)
-
- def __unicode__(self):
- return self.name
-
-class OfficialPet(models.Model):
- name = models.CharField(max_length=50)
-
- tags = TaggableManager(through=OfficialThroughModel)
-
- def __unicode__(self):
- return self.name
-
-class OfficialHousePet(OfficialPet):
- trained = models.BooleanField()
-
-
-class Media(models.Model):
- tags = TaggableManager()
-
- class Meta:
- abstract = True
-
-class Photo(Media):
- pass
-
-class Movie(Media):
- pass
-
-
-class ArticleTag(Tag):
- class Meta:
- proxy = True
-
- def slugify(self, tag, i=None):
- slug = "category-%s" % tag.lower()
-
- if i is not None:
- slug += "-%d" % i
- return slug
-
-class ArticleTaggedItem(TaggedItem):
- class Meta:
- proxy = True
-
- @classmethod
- def tag_model(self):
- return ArticleTag
-
-class Article(models.Model):
- title = models.CharField(max_length=100)
-
- tags = TaggableManager(through=ArticleTaggedItem)
diff --git a/awx/lib/site-packages/taggit/tests/runtests.py b/awx/lib/site-packages/taggit/tests/runtests.py
deleted file mode 100644
index 3e52cf18fe..0000000000
--- a/awx/lib/site-packages/taggit/tests/runtests.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/usr/bin/env python
-import os
-import sys
-
-from django.conf import settings
-
-if not settings.configured:
- settings.configure(
- DATABASES={
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3',
- }
- },
- INSTALLED_APPS=[
- 'django.contrib.contenttypes',
- 'taggit',
- 'taggit.tests',
- ]
- )
-
-
-from django.test.simple import DjangoTestSuiteRunner
-
-def runtests():
- runner = DjangoTestSuiteRunner()
- failures = runner.run_tests(['tests'], verbosity=1, interactive=True)
- sys.exit(failures)
-
-if __name__ == '__main__':
- runtests(*sys.argv[1:])
-
diff --git a/awx/lib/site-packages/taggit/tests/tests.py b/awx/lib/site-packages/taggit/tests/tests.py
deleted file mode 100644
index 6282db84ab..0000000000
--- a/awx/lib/site-packages/taggit/tests/tests.py
+++ /dev/null
@@ -1,486 +0,0 @@
-from unittest import TestCase as UnitTestCase
-
-import django
-from django.conf import settings
-from django.core.exceptions import ValidationError
-from django.db import connection
-from django.test import TestCase, TransactionTestCase
-
-from taggit.managers import TaggableManager
-from taggit.models import Tag, TaggedItem
-from taggit.tests.forms import (FoodForm, DirectFoodForm, CustomPKFoodForm,
- OfficialFoodForm)
-from taggit.tests.models import (Food, Pet, HousePet, DirectFood, DirectPet,
- DirectHousePet, TaggedPet, CustomPKFood, CustomPKPet, CustomPKHousePet,
- TaggedCustomPKPet, OfficialFood, OfficialPet, OfficialHousePet,
- OfficialThroughModel, OfficialTag, Photo, Movie, Article)
-from taggit.utils import parse_tags, edit_string_for_tags
-
-
-class BaseTaggingTest(object):
- def assert_tags_equal(self, qs, tags, sort=True, attr="name"):
- got = map(lambda tag: getattr(tag, attr), qs)
- if sort:
- got.sort()
- tags.sort()
- self.assertEqual(got, tags)
-
- def assert_num_queries(self, n, f, *args, **kwargs):
- original_DEBUG = settings.DEBUG
- settings.DEBUG = True
- current = len(connection.queries)
- try:
- f(*args, **kwargs)
- self.assertEqual(
- len(connection.queries) - current,
- n,
- )
- finally:
- settings.DEBUG = original_DEBUG
-
- def _get_form_str(self, form_str):
- if django.VERSION >= (1, 3):
- form_str %= {
- "help_start": '<span class="helptext">',
- "help_stop": "</span>"
- }
- else:
- form_str %= {
- "help_start": "",
- "help_stop": ""
- }
- return form_str
-
- def assert_form_renders(self, form, html):
- try:
- self.assertHTMLEqual(str(form), self._get_form_str(html))
- except AttributeError:
- self.assertEqual(str(form), self._get_form_str(html))
-
-
-class BaseTaggingTestCase(TestCase, BaseTaggingTest):
- pass
-
-class BaseTaggingTransactionTestCase(TransactionTestCase, BaseTaggingTest):
- pass
-
-
-class TagModelTestCase(BaseTaggingTransactionTestCase):
- food_model = Food
- tag_model = Tag
-
- def test_unique_slug(self):
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("Red", "red")
-
- def test_update(self):
- special = self.tag_model.objects.create(name="special")
- special.save()
-
- def test_add(self):
- apple = self.food_model.objects.create(name="apple")
- yummy = self.tag_model.objects.create(name="yummy")
- apple.tags.add(yummy)
-
- def test_slugify(self):
- a = Article.objects.create(title="django-taggit 1.0 Released")
- a.tags.add("awesome", "release", "AWESOME")
- self.assert_tags_equal(a.tags.all(), [
- "category-awesome",
- "category-release",
- "category-awesome-1"
- ], attr="slug")
-
-class TagModelDirectTestCase(TagModelTestCase):
- food_model = DirectFood
- tag_model = Tag
-
-class TagModelCustomPKTestCase(TagModelTestCase):
- food_model = CustomPKFood
- tag_model = Tag
-
-class TagModelOfficialTestCase(TagModelTestCase):
- food_model = OfficialFood
- tag_model = OfficialTag
-
-class TaggableManagerTestCase(BaseTaggingTestCase):
- food_model = Food
- pet_model = Pet
- housepet_model = HousePet
- taggeditem_model = TaggedItem
- tag_model = Tag
-
- def test_add_tag(self):
- apple = self.food_model.objects.create(name="apple")
- self.assertEqual(list(apple.tags.all()), [])
- self.assertEqual(list(self.food_model.tags.all()), [])
-
- apple.tags.add('green')
- self.assert_tags_equal(apple.tags.all(), ['green'])
- self.assert_tags_equal(self.food_model.tags.all(), ['green'])
-
- pear = self.food_model.objects.create(name="pear")
- pear.tags.add('green')
- self.assert_tags_equal(pear.tags.all(), ['green'])
- self.assert_tags_equal(self.food_model.tags.all(), ['green'])
-
- apple.tags.add('red')
- self.assert_tags_equal(apple.tags.all(), ['green', 'red'])
- self.assert_tags_equal(self.food_model.tags.all(), ['green', 'red'])
-
- self.assert_tags_equal(
- self.food_model.tags.most_common(),
- ['green', 'red'],
- sort=False
- )
-
- apple.tags.remove('green')
- self.assert_tags_equal(apple.tags.all(), ['red'])
- self.assert_tags_equal(self.food_model.tags.all(), ['green', 'red'])
- tag = self.tag_model.objects.create(name="delicious")
- apple.tags.add(tag)
- self.assert_tags_equal(apple.tags.all(), ["red", "delicious"])
-
- apple.delete()
- self.assert_tags_equal(self.food_model.tags.all(), ["green"])
-
- def test_add_queries(self):
- apple = self.food_model.objects.create(name="apple")
- # 1 query to see which tags exist
- # + 3 queries to create the tags.
- # + 6 queries to create the intermediary things (including SELECTs, to
- # make sure we don't double create.
- self.assert_num_queries(10, apple.tags.add, "red", "delicious", "green")
-
- pear = self.food_model.objects.create(name="pear")
- # 1 query to see which tags exist
- # + 4 queries to create the intermeidary things (including SELECTs, to
- # make sure we dont't double create.
- self.assert_num_queries(5, pear.tags.add, "green", "delicious")
-
- self.assert_num_queries(0, pear.tags.add)
-
- def test_require_pk(self):
- food_instance = self.food_model()
- self.assertRaises(ValueError, lambda: food_instance.tags.all())
-
- def test_delete_obj(self):
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("red")
- self.assert_tags_equal(apple.tags.all(), ["red"])
- strawberry = self.food_model.objects.create(name="strawberry")
- strawberry.tags.add("red")
- apple.delete()
- self.assert_tags_equal(strawberry.tags.all(), ["red"])
-
- def test_delete_bulk(self):
- apple = self.food_model.objects.create(name="apple")
- kitty = self.pet_model.objects.create(pk=apple.pk, name="kitty")
-
- apple.tags.add("red", "delicious", "fruit")
- kitty.tags.add("feline")
-
- self.food_model.objects.all().delete()
-
- self.assert_tags_equal(kitty.tags.all(), ["feline"])
-
- def test_lookup_by_tag(self):
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("red", "green")
- pear = self.food_model.objects.create(name="pear")
- pear.tags.add("green")
-
- self.assertEqual(
- list(self.food_model.objects.filter(tags__name__in=["red"])),
- [apple]
- )
- self.assertEqual(
- list(self.food_model.objects.filter(tags__name__in=["green"])),
- [apple, pear]
- )
-
- kitty = self.pet_model.objects.create(name="kitty")
- kitty.tags.add("fuzzy", "red")
- dog = self.pet_model.objects.create(name="dog")
- dog.tags.add("woof", "red")
- self.assertEqual(
- list(self.food_model.objects.filter(tags__name__in=["red"]).distinct()),
- [apple]
- )
-
- tag = self.tag_model.objects.get(name="woof")
- self.assertEqual(list(self.pet_model.objects.filter(tags__in=[tag])), [dog])
-
- cat = self.housepet_model.objects.create(name="cat", trained=True)
- cat.tags.add("fuzzy")
-
- self.assertEqual(
- map(lambda o: o.pk, self.pet_model.objects.filter(tags__name__in=["fuzzy"])),
- [kitty.pk, cat.pk]
- )
-
- def test_exclude(self):
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("red", "green", "delicious")
-
- pear = self.food_model.objects.create(name="pear")
- pear.tags.add("green", "delicious")
-
- guava = self.food_model.objects.create(name="guava")
-
- self.assertEqual(
- sorted(map(lambda o: o.pk, self.food_model.objects.exclude(tags__name__in=["red"]))),
- sorted([pear.pk, guava.pk]),
- )
-
- def test_similarity_by_tag(self):
- """Test that pears are more similar to apples than watermelons"""
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("green", "juicy", "small", "sour")
-
- pear = self.food_model.objects.create(name="pear")
- pear.tags.add("green", "juicy", "small", "sweet")
-
- watermelon = self.food_model.objects.create(name="watermelon")
- watermelon.tags.add("green", "juicy", "large", "sweet")
-
- similar_objs = apple.tags.similar_objects()
- self.assertEqual(similar_objs, [pear, watermelon])
- self.assertEqual(map(lambda x: x.similar_tags, similar_objs), [3, 2])
-
- def test_tag_reuse(self):
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("juicy", "juicy")
- self.assert_tags_equal(apple.tags.all(), ['juicy'])
-
- def test_query_traverse(self):
- spot = self.pet_model.objects.create(name='Spot')
- spike = self.pet_model.objects.create(name='Spike')
- spot.tags.add('scary')
- spike.tags.add('fluffy')
- lookup_kwargs = {
- '%s__name' % self.pet_model._meta.module_name: 'Spot'
- }
- self.assert_tags_equal(
- self.tag_model.objects.filter(**lookup_kwargs),
- ['scary']
- )
-
- def test_taggeditem_unicode(self):
- ross = self.pet_model.objects.create(name="ross")
- # I keep Ross Perot for a pet, what's it to you?
- ross.tags.add("president")
-
- self.assertEqual(
- unicode(self.taggeditem_model.objects.all()[0]),
- "ross tagged with president"
- )
-
- def test_abstract_subclasses(self):
- p = Photo.objects.create()
- p.tags.add("outdoors", "pretty")
- self.assert_tags_equal(
- p.tags.all(),
- ["outdoors", "pretty"]
- )
-
- m = Movie.objects.create()
- m.tags.add("hd")
- self.assert_tags_equal(
- m.tags.all(),
- ["hd"],
- )
-
- def test_field_api(self):
- # Check if tag field, which simulates m2m, has django-like api.
- field = self.food_model._meta.get_field('tags')
- self.assertTrue(hasattr(field, 'rel'))
- self.assertTrue(hasattr(field, 'related'))
- self.assertEqual(self.food_model, field.related.model)
-
-
-class TaggableManagerDirectTestCase(TaggableManagerTestCase):
- food_model = DirectFood
- pet_model = DirectPet
- housepet_model = DirectHousePet
- taggeditem_model = TaggedPet
-
-class TaggableManagerCustomPKTestCase(TaggableManagerTestCase):
- food_model = CustomPKFood
- pet_model = CustomPKPet
- housepet_model = CustomPKHousePet
- taggeditem_model = TaggedCustomPKPet
-
- def test_require_pk(self):
- # TODO with a charfield pk, pk is never None, so taggit has no way to
- # tell if the instance is saved or not
- pass
-
-class TaggableManagerOfficialTestCase(TaggableManagerTestCase):
- food_model = OfficialFood
- pet_model = OfficialPet
- housepet_model = OfficialHousePet
- taggeditem_model = OfficialThroughModel
- tag_model = OfficialTag
-
- def test_extra_fields(self):
- self.tag_model.objects.create(name="red")
- self.tag_model.objects.create(name="delicious", official=True)
- apple = self.food_model.objects.create(name="apple")
- apple.tags.add("delicious", "red")
-
- pear = self.food_model.objects.create(name="Pear")
- pear.tags.add("delicious")
-
- self.assertEqual(
- map(lambda o: o.pk, self.food_model.objects.filter(tags__official=False)),
- [apple.pk],
- )
-
-
-class TaggableFormTestCase(BaseTaggingTestCase):
- form_class = FoodForm
- food_model = Food
-
- def test_form(self):
- self.assertEqual(self.form_class.base_fields.keys(), ['name', 'tags'])
-
- f = self.form_class({'name': 'apple', 'tags': 'green, red, yummy'})
- self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr>
-<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value="green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""")
- f.save()
- apple = self.food_model.objects.get(name='apple')
- self.assert_tags_equal(apple.tags.all(), ['green', 'red', 'yummy'])
-
- f = self.form_class({'name': 'apple', 'tags': 'green, red, yummy, delicious'}, instance=apple)
- f.save()
- apple = self.food_model.objects.get(name='apple')
- self.assert_tags_equal(apple.tags.all(), ['green', 'red', 'yummy', 'delicious'])
- self.assertEqual(self.food_model.objects.count(), 1)
-
- f = self.form_class({"name": "raspberry"})
- self.assertFalse(f.is_valid())
-
- f = self.form_class(instance=apple)
- self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr>
-<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value="delicious, green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""")
-
- apple.tags.add('has,comma')
- f = self.form_class(instance=apple)
- self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr>
-<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value="&quot;has,comma&quot;, delicious, green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""")
-
- apple.tags.add('has space')
- f = self.form_class(instance=apple)
- self.assert_form_renders(f, """<tr><th><label for="id_name">Name:</label></th><td><input id="id_name" type="text" name="name" value="apple" maxlength="50" /></td></tr>
-<tr><th><label for="id_tags">Tags:</label></th><td><input type="text" name="tags" value="&quot;has space&quot;, &quot;has,comma&quot;, delicious, green, red, yummy" id="id_tags" /><br />%(help_start)sA comma-separated list of tags.%(help_stop)s</td></tr>""")
-
- def test_formfield(self):
- tm = TaggableManager(verbose_name='categories', help_text='Add some categories', blank=True)
- ff = tm.formfield()
- self.assertEqual(ff.label, 'Categories')
- self.assertEqual(ff.help_text, u'Add some categories')
- self.assertEqual(ff.required, False)
-
- self.assertEqual(ff.clean(""), [])
-
- tm = TaggableManager()
- ff = tm.formfield()
- self.assertRaises(ValidationError, ff.clean, "")
-
-class TaggableFormDirectTestCase(TaggableFormTestCase):
- form_class = DirectFoodForm
- food_model = DirectFood
-
-class TaggableFormCustomPKTestCase(TaggableFormTestCase):
- form_class = CustomPKFoodForm
- food_model = CustomPKFood
-
-class TaggableFormOfficialTestCase(TaggableFormTestCase):
- form_class = OfficialFoodForm
- food_model = OfficialFood
-
-
-class TagStringParseTestCase(UnitTestCase):
- """
- Ported from Jonathan Buchanan's `django-tagging
- <http://django-tagging.googlecode.com/>`_
- """
-
- def test_with_simple_space_delimited_tags(self):
- """
- Test with simple space-delimited tags.
- """
- self.assertEqual(parse_tags('one'), [u'one'])
- self.assertEqual(parse_tags('one two'), [u'one', u'two'])
- self.assertEqual(parse_tags('one two three'), [u'one', u'three', u'two'])
- self.assertEqual(parse_tags('one one two two'), [u'one', u'two'])
-
- def test_with_comma_delimited_multiple_words(self):
- """
- Test with comma-delimited multiple words.
- An unquoted comma in the input will trigger this.
- """
- self.assertEqual(parse_tags(',one'), [u'one'])
- self.assertEqual(parse_tags(',one two'), [u'one two'])
- self.assertEqual(parse_tags(',one two three'), [u'one two three'])
- self.assertEqual(parse_tags('a-one, a-two and a-three'),
- [u'a-one', u'a-two and a-three'])
-
- def test_with_double_quoted_multiple_words(self):
- """
- Test with double-quoted multiple words.
- A completed quote will trigger this. Unclosed quotes are ignored.
- """
- self.assertEqual(parse_tags('"one'), [u'one'])
- self.assertEqual(parse_tags('"one two'), [u'one', u'two'])
- self.assertEqual(parse_tags('"one two three'), [u'one', u'three', u'two'])
- self.assertEqual(parse_tags('"one two"'), [u'one two'])
- self.assertEqual(parse_tags('a-one "a-two and a-three"'),
- [u'a-one', u'a-two and a-three'])
-
- def test_with_no_loose_commas(self):
- """
- Test with no loose commas -- split on spaces.
- """
- self.assertEqual(parse_tags('one two "thr,ee"'), [u'one', u'thr,ee', u'two'])
-
- def test_with_loose_commas(self):
- """
- Loose commas - split on commas
- """
- self.assertEqual(parse_tags('"one", two three'), [u'one', u'two three'])
-
- def test_tags_with_double_quotes_can_contain_commas(self):
- """
- Double quotes can contain commas
- """
- self.assertEqual(parse_tags('a-one "a-two, and a-three"'),
- [u'a-one', u'a-two, and a-three'])
- self.assertEqual(parse_tags('"two", one, one, two, "one"'),
- [u'one', u'two'])
-
- def test_with_naughty_input(self):
- """
- Test with naughty input.
- """
- # Bad users! Naughty users!
- self.assertEqual(parse_tags(None), [])
- self.assertEqual(parse_tags(''), [])
- self.assertEqual(parse_tags('"'), [])
- self.assertEqual(parse_tags('""'), [])
- self.assertEqual(parse_tags('"' * 7), [])
- self.assertEqual(parse_tags(',,,,,,'), [])
- self.assertEqual(parse_tags('",",",",",",","'), [u','])
- self.assertEqual(parse_tags('a-one "a-two" and "a-three'),
- [u'a-one', u'a-three', u'a-two', u'and'])
-
- def test_recreation_of_tag_list_string_representations(self):
- plain = Tag.objects.create(name='plain')
- spaces = Tag.objects.create(name='spa ces')
- comma = Tag.objects.create(name='com,ma')
- self.assertEqual(edit_string_for_tags([plain]), u'plain')
- self.assertEqual(edit_string_for_tags([plain, spaces]), u'"spa ces", plain')
- self.assertEqual(edit_string_for_tags([plain, spaces, comma]), u'"com,ma", "spa ces", plain')
- self.assertEqual(edit_string_for_tags([plain, comma]), u'"com,ma", plain')
- self.assertEqual(edit_string_for_tags([comma, spaces]), u'"com,ma", "spa ces"')
diff --git a/awx/lib/site-packages/taggit/utils.py b/awx/lib/site-packages/taggit/utils.py
index 1b5e5a7f11..997c4f0a1b 100644
--- a/awx/lib/site-packages/taggit/utils.py
+++ b/awx/lib/site-packages/taggit/utils.py
@@ -1,5 +1,8 @@
-from django.utils.encoding import force_unicode
+from __future__ import unicode_literals
+
+from django.utils.encoding import force_text
from django.utils.functional import wraps
+from django.utils import six
def parse_tags(tagstring):
@@ -16,13 +19,13 @@ def parse_tags(tagstring):
if not tagstring:
return []
- tagstring = force_unicode(tagstring)
+ tagstring = force_text(tagstring)
# Special case - if there are no commas or double quotes in the
# input, we don't *do* a recall... I mean, we know we only need to
# split on spaces.
- if u',' not in tagstring and u'"' not in tagstring:
- words = list(set(split_strip(tagstring, u' ')))
+ if ',' not in tagstring and '"' not in tagstring:
+ words = list(set(split_strip(tagstring, ' ')))
words.sort()
return words
@@ -36,39 +39,39 @@ def parse_tags(tagstring):
i = iter(tagstring)
try:
while True:
- c = i.next()
- if c == u'"':
+ c = six.next(i)
+ if c == '"':
if buffer:
- to_be_split.append(u''.join(buffer))
+ to_be_split.append(''.join(buffer))
buffer = []
# Find the matching quote
open_quote = True
- c = i.next()
- while c != u'"':
+ c = six.next(i)
+ while c != '"':
buffer.append(c)
- c = i.next()
+ c = six.next(i)
if buffer:
- word = u''.join(buffer).strip()
+ word = ''.join(buffer).strip()
if word:
words.append(word)
buffer = []
open_quote = False
else:
- if not saw_loose_comma and c == u',':
+ if not saw_loose_comma and c == ',':
saw_loose_comma = True
buffer.append(c)
except StopIteration:
# If we were parsing an open quote which was never closed treat
# the buffer as unquoted.
if buffer:
- if open_quote and u',' in buffer:
+ if open_quote and ',' in buffer:
saw_loose_comma = True
- to_be_split.append(u''.join(buffer))
+ to_be_split.append(''.join(buffer))
if to_be_split:
if saw_loose_comma:
- delimiter = u','
+ delimiter = ','
else:
- delimiter = u' '
+ delimiter = ' '
for chunk in to_be_split:
words.extend(split_strip(chunk, delimiter))
words = list(set(words))
@@ -76,7 +79,7 @@ def parse_tags(tagstring):
return words
-def split_strip(string, delimiter=u','):
+def split_strip(string, delimiter=','):
"""
Splits ``string`` on ``delimiter``, stripping each resulting string
and returning a list of non-empty strings.
@@ -110,11 +113,11 @@ def edit_string_for_tags(tags):
names = []
for tag in tags:
name = tag.name
- if u',' in name or u' ' in name:
+ if ',' in name or ' ' in name:
names.append('"%s"' % name)
else:
names.append(name)
- return u', '.join(sorted(names))
+ return ', '.join(sorted(names))
def require_instance_manager(func):
diff --git a/awx/lib/site-packages/taggit/views.py b/awx/lib/site-packages/taggit/views.py
index 1e407f41c9..68c06dbf9b 100644
--- a/awx/lib/site-packages/taggit/views.py
+++ b/awx/lib/site-packages/taggit/views.py
@@ -1,3 +1,5 @@
+from __future__ import unicode_literals
+
from django.contrib.contenttypes.models import ContentType
from django.shortcuts import get_object_or_404
from django.views.generic.list import ListView