Enumerations for model field choices (django/db/models/enums.py)

Description

Custom enumeration types TextChoices, IntegerChoices, and Choices are now available as a way to define Field.choices.

TextChoices and IntegerChoices types are provided for text and integer fields.

The Choices class allows defining a compatible enumeration for other concrete data types.

These custom enumeration types support human-readable labels that can be translated and accessed via a property on the enumeration or its members.

See Enumeration types for more details and examples.

Enumeration types

In addition, Django provides enumeration types that you can subclass to define choices in a concise way:

from django.db import models
from django.utils.translation import gettext_lazy as _

class Student(models.Model):

    class YearInSchool(models.TextChoices):
        FRESHMAN = 'FR', _('Freshman')
        SOPHOMORE = 'SO', _('Sophomore')
        JUNIOR = 'JR', _('Junior')
        SENIOR = 'SR', _('Senior')
        GRADUATE = 'GR', _('Graduate')

    year_in_school = models.CharField(
        max_length=2,
        choices=YearInSchool.choices,
        default=YearInSchool.FRESHMAN,
    )

    def is_upperclass(self):
        return self.year_in_school in {
            self.YearInSchool.JUNIOR,
            self.YearInSchool.SENIOR,
        }

These work similarly to enum from Python’s standard library, but with some modifications:

  • Enum member values are a tuple of arguments to use when constructing the concrete data type. Django supports adding an extra string value to the end of this tuple to be used as the human-readable name, or label. The label can be a lazy translatable string. Thus, in most cases, the member value will be a (value, label) two-tuple. See below for an example of subclassing Choices using a more complex data type. If a tuple is not provided, or the last item is not a (lazy) string, the label is automatically generated from the member name.

  • A .label property is added to the members, returning the human-readable name.

  • A number of custom properties are added to the enumeration classes (.choices, .labels, .values, and .names) to make it easier to access lists of those separate parts of the enumeration. Use .choices as a suitable value to pass to choices in a field definition.

  • The use of enum.unique() is enforced to ensure that values cannot be defined multiple times, since duplicate values are unlikely to be intended in a field's choices.
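The label handling described in the first bullet can be sketched without Django. The function below is a hypothetical, stdlib-only approximation of the loop in ChoicesMeta.__new__ (shown in the enums.py listing); it recognizes only plain str labels, whereas Django also accepts lazy translation strings (Promise).

```python
def split_member_value(name, value):
    """Return (constructor_args_or_value, label) for one enum member.

    Hypothetical helper approximating ChoicesMeta.__new__. Django also
    treats lazy translation strings (Promise) as labels; this sketch
    only recognizes plain str.
    """
    if (
        isinstance(value, (list, tuple))
        and len(value) > 1
        and isinstance(value[-1], str)
    ):
        # The trailing string is the label; the rest are constructor args.
        *args, label = value
        return tuple(args), label
    # No explicit label: derive one from the member name.
    return value, name.replace("_", " ").title()


print(split_member_value("FRESHMAN", ("FR", "Freshman")))  # (('FR',), 'Freshman')
print(split_member_value("APOLLO_11", (1969, 7, 20, "Apollo 11 (Eagle)")))
print(split_member_value("JET_SKI", "J"))  # ('J', 'Jet Ski')
print(split_member_value("DIAMOND", 1))  # (1, 'Diamond')
```

An explicit label leaves a one-element tuple such as ('FR',), which the enum machinery unpacks as constructor arguments, so the member value becomes str('FR'). A tuple with only one item is never split, so its label falls back to the member name.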

Note that using YearInSchool.SENIOR, YearInSchool['SENIOR'], or YearInSchool('SR') to access or look up enum members works as expected, as do the .name and .value properties on the members.
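These lookups are plain enum behavior, so they can be checked without Django. The sketch below uses a str-mixin enum from the standard library as a stand-in for YearInSchool; it has no labels, since .label is added by Django's metaclass.

```python
import enum


class YearInSchool(str, enum.Enum):
    # Stand-in for the TextChoices example: plain values, no labels.
    FRESHMAN = "FR"
    SOPHOMORE = "SO"
    JUNIOR = "JR"
    SENIOR = "SR"
    GRADUATE = "GR"


by_attribute = YearInSchool.SENIOR
by_name = YearInSchool["SENIOR"]
by_value = YearInSchool("SR")

# All three lookups return the same singleton member.
print(by_attribute is by_name is by_value)  # True
print(by_value.name, by_value.value)  # SENIOR SR
```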

If you don’t need to have the human-readable names translated, you can have them inferred from the member name (replacing underscores with spaces and using title-case):

>>> class Vehicle(models.TextChoices):
...     CAR = 'C'
...     TRUCK = 'T'
...     JET_SKI = 'J'
...
>>> Vehicle.JET_SKI.label
'Jet Ski'

Since the case where the enum values need to be integers is extremely common, Django provides an IntegerChoices class. For example:

class Card(models.Model):

    class Suit(models.IntegerChoices):
        DIAMOND = 1
        SPADE = 2
        HEART = 3
        CLUB = 4

    suit = models.IntegerField(choices=Suit.choices)

It is also possible to make use of the Enum Functional API with the caveat that labels are automatically generated as highlighted above:

>>> MedalType = models.TextChoices('MedalType', 'GOLD SILVER BRONZE')
>>> MedalType.choices
[('GOLD', 'Gold'), ('SILVER', 'Silver'), ('BRONZE', 'Bronze')]
>>> Place = models.IntegerChoices('Place', 'FIRST SECOND THIRD')
>>> Place.choices
[(1, 'First'), (2, 'Second'), (3, 'Third')]
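With the functional API, the enum module asks the class for each member's value through _generate_next_value_, which TextChoices overrides to return the member name (see the enums.py listing), while integer enums count up from 1. A stdlib-only sketch of both cases; TextBase and as_choices are hypothetical names, and the labels are derived from member names the same way Django does when no label is given.

```python
import enum


class TextBase(str, enum.Enum):
    """Hypothetical stand-in for TextChoices' value generation."""

    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        # Same rule as TextChoices: the member value is the member name.
        return name


def as_choices(enum_cls):
    """Build (value, label) pairs, deriving labels from member names."""
    return [
        (member.value, member.name.replace("_", " ").title())
        for member in enum_cls
    ]


MedalType = TextBase("MedalType", "GOLD SILVER BRONZE")
Place = enum.IntEnum("Place", "FIRST SECOND THIRD")

print(as_choices(MedalType))  # [('GOLD', 'Gold'), ('SILVER', 'Silver'), ('BRONZE', 'Bronze')]
print(as_choices(Place))  # [(1, 'First'), (2, 'Second'), (3, 'Third')]
```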

If you require support for a concrete data type other than int or str, you can subclass Choices and the required concrete data type, e.g. date for use with DateField:

import datetime

from django.db import models

class MoonLandings(datetime.date, models.Choices):
    APOLLO_11 = 1969, 7, 20, 'Apollo 11 (Eagle)'
    APOLLO_12 = 1969, 11, 19, 'Apollo 12 (Intrepid)'
    APOLLO_14 = 1971, 2, 5, 'Apollo 14 (Antares)'
    APOLLO_15 = 1971, 7, 30, 'Apollo 15 (Falcon)'
    APOLLO_16 = 1972, 4, 21, 'Apollo 16 (Orion)'
    APOLLO_17 = 1972, 12, 11, 'Apollo 17 (Challenger)'
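For a Choices subclass with a concrete data type, the tuple left after the label is removed is passed to that type's constructor. The standard library performs this step for any data-type mixin, so the mechanism can be illustrated without Django; this sketch omits labels, which require Django's metaclass.

```python
import datetime
import enum


class MoonLandings(datetime.date, enum.Enum):
    # Each tuple is passed as positional arguments to datetime.date().
    APOLLO_11 = 1969, 7, 20
    APOLLO_12 = 1969, 11, 19


landing = MoonLandings.APOLLO_11
print(isinstance(landing, datetime.date))  # True
print(landing.year, landing.month, landing.day)  # 1969 7 20
print(landing.value)  # 1969-07-20
# Lookup by value works with a plain date.
print(MoonLandings(datetime.date(1969, 11, 19)).name)  # APOLLO_12
```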

There are some additional caveats to be aware of:

  • Enumeration types do not support named groups.

  • Because an enumeration with a concrete data type requires all values to match the type, overriding the blank label cannot be achieved by creating a member with a value of None. Instead, set the __empty__ attribute on the class:

    class Answer(models.IntegerChoices):
        NO = 0, _('No')
        YES = 1, _('Yes')
    
        __empty__ = _('(Unknown)')
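
Because __empty__ is a dunder name, enum never turns it into a member; Django's ChoicesMeta instead checks for it with hasattr() and prepends it to the derived lists. The sketch below mirrors the choices, values, labels, and names properties from the enums.py listing as plain functions (hypothetical names) over a standard IntEnum. Labels come from member names here because a plain enum carries no labels.

```python
import enum


class Answer(enum.IntEnum):
    NO = 0
    YES = 1
    # Dunder names are ignored by enum, so this never becomes a member.
    __empty__ = "(Unknown)"


def choices(enum_cls):
    """Mirror ChoicesMeta.choices: prepend (None, __empty__) when defined."""
    empty = [(None, enum_cls.__empty__)] if hasattr(enum_cls, "__empty__") else []
    # Labels derived from member names; Django would use explicit labels.
    return empty + [(m.value, m.name.replace("_", " ").title()) for m in enum_cls]


def values(enum_cls):
    return [value for value, _ in choices(enum_cls)]


def labels(enum_cls):
    return [label for _, label in choices(enum_cls)]


def names(enum_cls):
    empty = ["__empty__"] if hasattr(enum_cls, "__empty__") else []
    return empty + [m.name for m in enum_cls]


print(choices(Answer))  # [(None, '(Unknown)'), (0, 'No'), (1, 'Yes')]
print(values(Answer))  # [None, 0, 1]
print(names(Answer))  # ['__empty__', 'NO', 'YES']
```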
    

New in version 3.0: The TextChoices, IntegerChoices, and Choices classes were added.

django/db/models/enums.py

import enum

from django.utils.functional import Promise

__all__ = ["Choices", "IntegerChoices", "TextChoices"]


class ChoicesMeta(enum.EnumMeta):
    """A metaclass for creating enum choices."""

    def __new__(metacls, classname, bases, classdict):
        labels = []
        for key in classdict._member_names:
            value = classdict[key]
            if (
                isinstance(value, (list, tuple))
                and len(value) > 1
                and isinstance(value[-1], (Promise, str))
            ):
                *value, label = value
                value = tuple(value)
            else:
                label = key.replace("_", " ").title()
            labels.append(label)
            # Use dict.__setitem__() to suppress defenses against double
            # assignment in enum's classdict.
            dict.__setitem__(classdict, key, value)
        cls = super().__new__(metacls, classname, bases, classdict)
        cls._value2label_map_ = dict(zip(cls._value2member_map_, labels))
        # Add a label property to instances of enum which uses the enum member
        # that is passed in as "self" as the value to use when looking up the
        # label in the choices.
        cls.label = property(lambda self: cls._value2label_map_.get(self.value))
        return enum.unique(cls)

    def __contains__(cls, member):
        if not isinstance(member, enum.Enum):
            # Allow non-enums to match against member values.
            return member in {x.value for x in cls}
        return super().__contains__(member)

    @property
    def names(cls):
        empty = ["__empty__"] if hasattr(cls, "__empty__") else []
        return empty + [member.name for member in cls]

    @property
    def choices(cls):
        empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
        return empty + [(member.value, member.label) for member in cls]

    @property
    def labels(cls):
        return [label for _, label in cls.choices]

    @property
    def values(cls):
        return [value for value, _ in cls.choices]


class Choices(enum.Enum, metaclass=ChoicesMeta):
    """Class for creating enumerated choices."""

    def __str__(self):
        """
        Use value when cast to str, so that Choices set as model instance
        attributes are rendered as expected in templates and similar contexts.
        """
        return str(self.value)


class IntegerChoices(int, Choices):
    """Class for creating enumerated integer choices."""

    pass


class TextChoices(str, Choices):
    """Class for creating enumerated string choices."""

    def _generate_next_value_(name, start, count, last_values):
        return name
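
Two behaviors in this listing can be reproduced on a plain enum: ChoicesMeta.__contains__ accepts raw values as well as members, and Choices.__str__ renders the stored value, which is what templates display. A minimal stdlib sketch; ChoicesLikeMeta is a hypothetical name, and the method bodies follow the listing above.

```python
import enum


class ChoicesLikeMeta(enum.EnumMeta):
    """Hypothetical metaclass copying ChoicesMeta.__contains__."""

    def __contains__(cls, member):
        if not isinstance(member, enum.Enum):
            # Allow raw values such as "SR" to match member values.
            return member in {x.value for x in cls}
        return super().__contains__(member)


class YearInSchool(str, enum.Enum, metaclass=ChoicesLikeMeta):
    JUNIOR = "JR"
    SENIOR = "SR"

    def __str__(self):
        # Same idea as Choices.__str__: render the stored value.
        return str(self.value)


print("SR" in YearInSchool)  # True
print("XX" in YearInSchool)  # False
print(YearInSchool.SENIOR in YearInSchool)  # True
print(str(YearInSchool.SENIOR))  # SR
```

Without the __str__ override, str() of a str-mixin member would return 'YearInSchool.SENIOR' rather than 'SR'.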

tests/migrations/test_writer.py

import datetime
import decimal
import enum
import functools
import math
import os
import re
import uuid
from unittest import mock

import custom_migration_operations.more_operations
import custom_migration_operations.operations

from django import get_version
from django.conf import SettingsReference, settings
from django.core.validators import EmailValidator, RegexValidator
from django.db import migrations, models
from django.db.migrations.serializer import BaseSerializer
from django.db.migrations.writer import MigrationWriter, OperationWriter
from django.test import SimpleTestCase
from django.utils.deconstruct import deconstructible
from django.utils.functional import SimpleLazyObject
from django.utils.timezone import get_default_timezone, get_fixed_timezone, utc
from django.utils.translation import gettext_lazy as _

from .models import FoodManager, FoodQuerySet


class DeconstructibleInstances:
    def deconstruct(self):
        return ("DeconstructibleInstances", [], {})


class Money(decimal.Decimal):
    def deconstruct(self):
        return (
            "%s.%s" % (self.__class__.__module__, self.__class__.__name__),
            [str(self)],
            {},
        )


class TestModel1:
    def upload_to(self):
        return "/somewhere/dynamic/"

    thing = models.FileField(upload_to=upload_to)


class TextEnum(enum.Enum):
    A = "a-value"
    B = "value-b"


class TextTranslatedEnum(enum.Enum):
    A = _("a-value")
    B = _("value-b")


class BinaryEnum(enum.Enum):
    A = b"a-value"
    B = b"value-b"


class IntEnum(enum.IntEnum):
    A = 1
    B = 2


class OperationWriterTests(SimpleTestCase):
    def test_empty_signature(self):
        operation = custom_migration_operations.operations.TestOperation()
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff, "custom_migration_operations.operations.TestOperation(\n" "),"
        )

    def test_args_signature(self):
        operation = custom_migration_operations.operations.ArgsOperation(1, 2)
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.ArgsOperation(\n"
            "    arg1=1,\n"
            "    arg2=2,\n"
            "),",
        )

    def test_kwargs_signature(self):
        operation = custom_migration_operations.operations.KwargsOperation(kwarg1=1)
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.KwargsOperation(\n"
            "    kwarg1=1,\n"
            "),",
        )

    def test_args_kwargs_signature(self):
        operation = custom_migration_operations.operations.ArgsKwargsOperation(
            1, 2, kwarg2=4
        )
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.ArgsKwargsOperation(\n"
            "    arg1=1,\n"
            "    arg2=2,\n"
            "    kwarg2=4,\n"
            "),",
        )

    def test_nested_args_signature(self):
        operation = custom_migration_operations.operations.ArgsOperation(
            custom_migration_operations.operations.ArgsOperation(1, 2),
            custom_migration_operations.operations.KwargsOperation(kwarg1=3, kwarg2=4),
        )
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.ArgsOperation(\n"
            "    arg1=custom_migration_operations.operations.ArgsOperation(\n"
            "        arg1=1,\n"
            "        arg2=2,\n"
            "    ),\n"
            "    arg2=custom_migration_operations.operations.KwargsOperation(\n"
            "        kwarg1=3,\n"
            "        kwarg2=4,\n"
            "    ),\n"
            "),",
        )

    def test_multiline_args_signature(self):
        operation = custom_migration_operations.operations.ArgsOperation(
            "test\n    arg1", "test\narg2"
        )
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.ArgsOperation(\n"
            "    arg1='test\\n    arg1',\n"
            "    arg2='test\\narg2',\n"
            "),",
        )

    def test_expand_args_signature(self):
        operation = custom_migration_operations.operations.ExpandArgsOperation([1, 2])
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.ExpandArgsOperation(\n"
            "    arg=[\n"
            "        1,\n"
            "        2,\n"
            "    ],\n"
            "),",
        )

    def test_nested_operation_expand_args_signature(self):
        operation = custom_migration_operations.operations.ExpandArgsOperation(
            arg=[
                custom_migration_operations.operations.KwargsOperation(
                    kwarg1=1, kwarg2=2,
                ),
            ]
        )
        buff, imports = OperationWriter(operation, indentation=0).serialize()
        self.assertEqual(imports, {"import custom_migration_operations.operations"})
        self.assertEqual(
            buff,
            "custom_migration_operations.operations.ExpandArgsOperation(\n"
            "    arg=[\n"
            "        custom_migration_operations.operations.KwargsOperation(\n"
            "            kwarg1=1,\n"
            "            kwarg2=2,\n"
            "        ),\n"
            "    ],\n"
            "),",
        )


class WriterTests(SimpleTestCase):
    """
    Tests the migration writer (makes migration files from Migration instances)
    """

    class NestedEnum(enum.IntEnum):
        A = 1
        B = 2

    class NestedChoices(models.TextChoices):
        X = "X", "X value"
        Y = "Y", "Y value"

    def safe_exec(self, string, value=None):
        d = {}
        try:
            exec(string, globals(), d)
        except Exception as e:
            if value:
                self.fail(
                    "Could not exec %r (from value %r): %s" % (string.strip(), value, e)
                )
            else:
                self.fail("Could not exec %r: %s" % (string.strip(), e))
        return d

    def serialize_round_trip(self, value):
        string, imports = MigrationWriter.serialize(value)
        return self.safe_exec(
            "%s\ntest_value_result = %s" % ("\n".join(imports), string), value
        )["test_value_result"]

    def assertSerializedEqual(self, value):
        self.assertEqual(self.serialize_round_trip(value), value)

    def assertSerializedResultEqual(self, value, target):
        self.assertEqual(MigrationWriter.serialize(value), target)

    def assertSerializedFieldEqual(self, value):
        new_value = self.serialize_round_trip(value)
        self.assertEqual(value.__class__, new_value.__class__)
        self.assertEqual(value.max_length, new_value.max_length)
        self.assertEqual(value.null, new_value.null)
        self.assertEqual(value.unique, new_value.unique)

    def test_serialize_numbers(self):
        self.assertSerializedEqual(1)
        self.assertSerializedEqual(1.2)
        self.assertTrue(math.isinf(self.serialize_round_trip(float("inf"))))
        self.assertTrue(math.isinf(self.serialize_round_trip(float("-inf"))))
        self.assertTrue(math.isnan(self.serialize_round_trip(float("nan"))))

        self.assertSerializedEqual(decimal.Decimal("1.3"))
        self.assertSerializedResultEqual(
            decimal.Decimal("1.3"), ("Decimal('1.3')", {"from decimal import Decimal"})
        )

        self.assertSerializedEqual(Money("1.3"))
        self.assertSerializedResultEqual(
            Money("1.3"),
            ("migrations.test_writer.Money('1.3')", {"import migrations.test_writer"}),
        )

    def test_serialize_constants(self):
        self.assertSerializedEqual(None)
        self.assertSerializedEqual(True)
        self.assertSerializedEqual(False)

    def test_serialize_strings(self):
        self.assertSerializedEqual(b"foobar")
        string, imports = MigrationWriter.serialize(b"foobar")
        self.assertEqual(string, "b'foobar'")
        self.assertSerializedEqual("föobár")
        string, imports = MigrationWriter.serialize("foobar")
        self.assertEqual(string, "'foobar'")

    def test_serialize_multiline_strings(self):
        self.assertSerializedEqual(b"foo\nbar")
        string, imports = MigrationWriter.serialize(b"foo\nbar")
        self.assertEqual(string, "b'foo\\nbar'")
        self.assertSerializedEqual("föo\nbár")
        string, imports = MigrationWriter.serialize("foo\nbar")
        self.assertEqual(string, "'foo\\nbar'")

    def test_serialize_collections(self):
        self.assertSerializedEqual({1: 2})
        self.assertSerializedEqual(["a", 2, True, None])
        self.assertSerializedEqual({2, 3, "eighty"})
        self.assertSerializedEqual({"lalalala": ["yeah", "no", "maybe"]})
        self.assertSerializedEqual(_("Hello"))

    def test_serialize_builtin_types(self):
        self.assertSerializedEqual([list, tuple, dict, set, frozenset])
        self.assertSerializedResultEqual(
            [list, tuple, dict, set, frozenset],
            ("[list, tuple, dict, set, frozenset]", set()),
        )

    def test_serialize_lazy_objects(self):
        pattern = re.compile(r"^foo$")
        lazy_pattern = SimpleLazyObject(lambda: pattern)
        self.assertEqual(self.serialize_round_trip(lazy_pattern), pattern)

    def test_serialize_enums(self):
        self.assertSerializedResultEqual(
            TextEnum.A,
            ("migrations.test_writer.TextEnum['A']", {"import migrations.test_writer"}),
        )
        self.assertSerializedResultEqual(
            TextTranslatedEnum.A,
            (
                "migrations.test_writer.TextTranslatedEnum['A']",
                {"import migrations.test_writer"},
            ),
        )
        self.assertSerializedResultEqual(
            BinaryEnum.A,
            (
                "migrations.test_writer.BinaryEnum['A']",
                {"import migrations.test_writer"},
            ),
        )
        self.assertSerializedResultEqual(
            IntEnum.B,
            ("migrations.test_writer.IntEnum['B']", {"import migrations.test_writer"}),
        )
        self.assertSerializedResultEqual(
            self.NestedEnum.A,
            (
                "migrations.test_writer.WriterTests.NestedEnum['A']",
                {"import migrations.test_writer"},
            ),
        )
        self.assertSerializedEqual(self.NestedEnum.A)

        field = models.CharField(
            default=TextEnum.B, choices=[(m.value, m) for m in TextEnum]
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.CharField(choices=["
            "('a-value', migrations.test_writer.TextEnum['A']), "
            "('value-b', migrations.test_writer.TextEnum['B'])], "
            "default=migrations.test_writer.TextEnum['B'])",
        )
        field = models.CharField(
            default=TextTranslatedEnum.A,
            choices=[(m.value, m) for m in TextTranslatedEnum],
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.CharField(choices=["
            "('a-value', migrations.test_writer.TextTranslatedEnum['A']), "
            "('value-b', migrations.test_writer.TextTranslatedEnum['B'])], "
            "default=migrations.test_writer.TextTranslatedEnum['A'])",
        )
        field = models.CharField(
            default=BinaryEnum.B, choices=[(m.value, m) for m in BinaryEnum]
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.CharField(choices=["
            "(b'a-value', migrations.test_writer.BinaryEnum['A']), "
            "(b'value-b', migrations.test_writer.BinaryEnum['B'])], "
            "default=migrations.test_writer.BinaryEnum['B'])",
        )
        field = models.IntegerField(
            default=IntEnum.A, choices=[(m.value, m) for m in IntEnum]
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.IntegerField(choices=["
            "(1, migrations.test_writer.IntEnum['A']), "
            "(2, migrations.test_writer.IntEnum['B'])], "
            "default=migrations.test_writer.IntEnum['A'])",
        )

    def test_serialize_choices(self):
        class TextChoices(models.TextChoices):
            A = "A", "A value"
            B = "B", "B value"

        class IntegerChoices(models.IntegerChoices):
            A = 1, "One"
            B = 2, "Two"

        class DateChoices(datetime.date, models.Choices):
            DATE_1 = 1969, 7, 20, "First date"
            DATE_2 = 1969, 11, 19, "Second date"

        self.assertSerializedResultEqual(TextChoices.A, ("'A'", set()))
        self.assertSerializedResultEqual(IntegerChoices.A, ("1", set()))
        self.assertSerializedResultEqual(
            DateChoices.DATE_1, ("datetime.date(1969, 7, 20)", {"import datetime"}),
        )
        field = models.CharField(default=TextChoices.B, choices=TextChoices.choices)
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.CharField(choices=[('A', 'A value'), ('B', 'B value')], "
            "default='B')",
        )
        field = models.IntegerField(
            default=IntegerChoices.B, choices=IntegerChoices.choices
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string, "models.IntegerField(choices=[(1, 'One'), (2, 'Two')], default=2)",
        )
        field = models.DateField(
            default=DateChoices.DATE_2, choices=DateChoices.choices
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.DateField(choices=["
            "(datetime.date(1969, 7, 20), 'First date'), "
            "(datetime.date(1969, 11, 19), 'Second date')], "
            "default=datetime.date(1969, 11, 19))",
        )

    def test_serialize_nested_class(self):
        for nested_cls in [self.NestedEnum, self.NestedChoices]:
            cls_name = nested_cls.__name__
            with self.subTest(cls_name):
                self.assertSerializedResultEqual(
                    nested_cls,
                    (
                        "migrations.test_writer.WriterTests.%s" % cls_name,
                        {"import migrations.test_writer"},
                    ),
                )

    def test_serialize_uuid(self):
        self.assertSerializedEqual(uuid.uuid1())
        self.assertSerializedEqual(uuid.uuid4())

        uuid_a = uuid.UUID("5c859437-d061-4847-b3f7-e6b78852f8c8")
        uuid_b = uuid.UUID("c7853ec1-2ea3-4359-b02d-b54e8f1bcee2")
        self.assertSerializedResultEqual(
            uuid_a,
            ("uuid.UUID('5c859437-d061-4847-b3f7-e6b78852f8c8')", {"import uuid"}),
        )
        self.assertSerializedResultEqual(
            uuid_b,
            ("uuid.UUID('c7853ec1-2ea3-4359-b02d-b54e8f1bcee2')", {"import uuid"}),
        )

        field = models.UUIDField(
            choices=((uuid_a, "UUID A"), (uuid_b, "UUID B")), default=uuid_a
        )
        string = MigrationWriter.serialize(field)[0]
        self.assertEqual(
            string,
            "models.UUIDField(choices=["
            "(uuid.UUID('5c859437-d061-4847-b3f7-e6b78852f8c8'), 'UUID A'), "
            "(uuid.UUID('c7853ec1-2ea3-4359-b02d-b54e8f1bcee2'), 'UUID B')], "
            "default=uuid.UUID('5c859437-d061-4847-b3f7-e6b78852f8c8'))",
        )

    def test_serialize_functions(self):
        with self.assertRaisesMessage(ValueError, "Cannot serialize function: lambda"):
            self.assertSerializedEqual(lambda x: 42)
        self.assertSerializedEqual(models.SET_NULL)
        string, imports = MigrationWriter.serialize(models.SET(42))
        self.assertEqual(string, "models.SET(42)")
        self.serialize_round_trip(models.SET(42))

    def test_serialize_datetime(self):
        self.assertSerializedEqual(datetime.datetime.utcnow())
        self.assertSerializedEqual(datetime.datetime.utcnow)
        self.assertSerializedEqual(datetime.datetime.today())
        self.assertSerializedEqual(datetime.datetime.today)
        self.assertSerializedEqual(datetime.date.today())
        self.assertSerializedEqual(datetime.date.today)
        self.assertSerializedEqual(datetime.datetime.now().time())
        self.assertSerializedEqual(
            datetime.datetime(2014, 1, 1, 1, 1, tzinfo=get_default_timezone())
        )
        self.assertSerializedEqual(
            datetime.datetime(2013, 12, 31, 22, 1, tzinfo=get_fixed_timezone(180))
        )
        self.assertSerializedResultEqual(
            datetime.datetime(2014, 1, 1, 1, 1),
            ("datetime.datetime(2014, 1, 1, 1, 1)", {"import datetime"}),
        )
        self.assertSerializedResultEqual(
            datetime.datetime(2012, 1, 1, 1, 1, tzinfo=utc),
            (
                "datetime.datetime(2012, 1, 1, 1, 1, tzinfo=utc)",
                {"import datetime", "from django.utils.timezone import utc"},
            ),
        )

    def test_serialize_fields(self):
        self.assertSerializedFieldEqual(models.CharField(max_length=255))
        self.assertSerializedResultEqual(
            models.CharField(max_length=255),
            ("models.CharField(max_length=255)", {"from django.db import models"}),
        )
        self.assertSerializedFieldEqual(models.TextField(null=True, blank=True))
        self.assertSerializedResultEqual(
            models.TextField(null=True, blank=True),
            (
                "models.TextField(blank=True, null=True)",
                {"from django.db import models"},
            ),
        )

    def test_serialize_settings(self):
        self.assertSerializedEqual(
            SettingsReference(settings.AUTH_USER_MODEL, "AUTH_USER_MODEL")
        )
        self.assertSerializedResultEqual(
            SettingsReference("someapp.model", "AUTH_USER_MODEL"),
            ("settings.AUTH_USER_MODEL", {"from django.conf import settings"}),
        )

    def test_serialize_iterators(self):
        self.assertSerializedResultEqual(
            ((x, x * x) for x in range(3)), ("((0, 0), (1, 1), (2, 4))", set())
        )

    def test_serialize_compiled_regex(self):
        """
        Make sure compiled regex can be serialized.
        """
        regex = re.compile(r"^\w+$")
        self.assertSerializedEqual(regex)

    def test_serialize_class_based_validators(self):
        """
        Ticket #22943: Test serialization of class-based validators, including
        compiled regexes.
        """
        validator = RegexValidator(message="hello")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string, "django.core.validators.RegexValidator(message='hello')"
        )
        self.serialize_round_trip(validator)

        # Test with a compiled regex.
        validator = RegexValidator(regex=re.compile(r"^\w+$"))
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string,
            "django.core.validators.RegexValidator(regex=re.compile('^\\\\w+$'))",
        )
        self.serialize_round_trip(validator)

        # Test a string regex with flag
        validator = RegexValidator(r"^[0-9]+$", flags=re.S)
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string,
            "django.core.validators.RegexValidator('^[0-9]+$', flags=re.RegexFlag['DOTALL'])",
        )
        self.serialize_round_trip(validator)

        # Test message and code
        validator = RegexValidator("^[-a-zA-Z0-9_]+$", "Invalid", "invalid")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string,
            "django.core.validators.RegexValidator('^[-a-zA-Z0-9_]+$', 'Invalid', 'invalid')",
        )
        self.serialize_round_trip(validator)

        # Test with a subclass.
        validator = EmailValidator(message="hello")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string, "django.core.validators.EmailValidator(message='hello')"
        )
        self.serialize_round_trip(validator)

        validator = deconstructible(path="migrations.test_writer.EmailValidator")(
            EmailValidator
        )(message="hello")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string, "migrations.test_writer.EmailValidator(message='hello')"
        )

        validator = deconstructible(path="custom.EmailValidator")(EmailValidator)(
            message="hello"
        )
        with self.assertRaisesMessage(ImportError, "No module named 'custom'"):
            MigrationWriter.serialize(validator)

        validator = deconstructible(path="django.core.validators.EmailValidator2")(
            EmailValidator
        )(message="hello")
        with self.assertRaisesMessage(
            ValueError,
            "Could not find object EmailValidator2 in django.core.validators.",
        ):
            MigrationWriter.serialize(validator)

    def test_serialize_empty_nonempty_tuple(self):
        """
        Ticket #22679: makemigrations generates invalid code for (an empty
        tuple) default_permissions = ()
        """
        empty_tuple = ()
        one_item_tuple = ("a",)
        many_items_tuple = ("a", "b", "c")
        self.assertSerializedEqual(empty_tuple)
        self.assertSerializedEqual(one_item_tuple)
        self.assertSerializedEqual(many_items_tuple)

    def test_serialize_range(self):
        string, imports = MigrationWriter.serialize(range(1, 5))
        self.assertEqual(string, "range(1, 5)")
        self.assertEqual(imports, set())

    def test_serialize_builtins(self):
        string, imports = MigrationWriter.serialize(range)
        self.assertEqual(string, "range")
        self.assertEqual(imports, set())

    def test_serialize_unbound_method_reference(self):
        """An unbound method used within a class body can be serialized."""
        self.serialize_round_trip(TestModel1.thing)

    def test_serialize_local_function_reference(self):
        """A reference in a local scope can't be serialized."""

        class TestModel2:
            def upload_to(self):
                return "somewhere dynamic"

            thing = models.FileField(upload_to=upload_to)

        with self.assertRaisesMessage(
            ValueError, "Could not find function upload_to in migrations.test_writer"
        ):
            self.serialize_round_trip(TestModel2.thing)

    def test_serialize_managers(self):
        self.assertSerializedEqual(models.Manager())
        self.assertSerializedResultEqual(
            FoodQuerySet.as_manager(),
            (
                "migrations.models.FoodQuerySet.as_manager()",
                {"import migrations.models"},
            ),
        )
        self.assertSerializedEqual(FoodManager("a", "b"))
        self.assertSerializedEqual(FoodManager("x", "y", c=3, d=4))

    def test_serialize_frozensets(self):
        self.assertSerializedEqual(frozenset())
        self.assertSerializedEqual(frozenset("let it go"))
455            self.assertSerializedEqual(lambda x: 42)
456        self.assertSerializedEqual(models.SET_NULL)
457        string, imports = MigrationWriter.serialize(models.SET(42))
458        self.assertEqual(string, "models.SET(42)")
459        self.serialize_round_trip(models.SET(42))
460
461    def test_serialize_datetime(self):
462        self.assertSerializedEqual(datetime.datetime.utcnow())
463        self.assertSerializedEqual(datetime.datetime.utcnow)
464        self.assertSerializedEqual(datetime.datetime.today())
465        self.assertSerializedEqual(datetime.datetime.today)
466        self.assertSerializedEqual(datetime.date.today())
467        self.assertSerializedEqual(datetime.date.today)
468        self.assertSerializedEqual(datetime.datetime.now().time())
469        self.assertSerializedEqual(
470            datetime.datetime(2014, 1, 1, 1, 1, tzinfo=get_default_timezone())
471        )
472        self.assertSerializedEqual(
473            datetime.datetime(2013, 12, 31, 22, 1, tzinfo=get_fixed_timezone(180))
474        )
475        self.assertSerializedResultEqual(
476            datetime.datetime(2014, 1, 1, 1, 1),
477            ("datetime.datetime(2014, 1, 1, 1, 1)", {"import datetime"}),
478        )
479        self.assertSerializedResultEqual(
480            datetime.datetime(2012, 1, 1, 1, 1, tzinfo=utc),
481            (
482                "datetime.datetime(2012, 1, 1, 1, 1, tzinfo=utc)",
483                {"import datetime", "from django.utils.timezone import utc"},
484            ),
485        )
486
487    def test_serialize_fields(self):
488        self.assertSerializedFieldEqual(models.CharField(max_length=255))
489        self.assertSerializedResultEqual(
490            models.CharField(max_length=255),
491            ("models.CharField(max_length=255)", {"from django.db import models"}),
492        )
493        self.assertSerializedFieldEqual(models.TextField(null=True, blank=True))
494        self.assertSerializedResultEqual(
495            models.TextField(null=True, blank=True),
496            (
497                "models.TextField(blank=True, null=True)",
498                {"from django.db import models"},
499            ),
500        )
501
502    def test_serialize_settings(self):
503        self.assertSerializedEqual(
504            SettingsReference(settings.AUTH_USER_MODEL, "AUTH_USER_MODEL")
505        )
506        self.assertSerializedResultEqual(
507            SettingsReference("someapp.model", "AUTH_USER_MODEL"),
508            ("settings.AUTH_USER_MODEL", {"from django.conf import settings"}),
509        )
510
511    def test_serialize_iterators(self):
512        self.assertSerializedResultEqual(
513            ((x, x * x) for x in range(3)), ("((0, 0), (1, 1), (2, 4))", set())
514        )
515
516    def test_serialize_compiled_regex(self):
517        """
518        Make sure compiled regex can be serialized.
519        """
520        regex = re.compile(r"^\w+$")
521        self.assertSerializedEqual(regex)
522
523    def test_serialize_class_based_validators(self):
524        """
525        Ticket #22943: Test serialization of class-based validators, including
526        compiled regexes.
527        """
528        validator = RegexValidator(message="hello")
529        string = MigrationWriter.serialize(validator)[0]
530        self.assertEqual(
531            string, "django.core.validators.RegexValidator(message='hello')"
532        )
533        self.serialize_round_trip(validator)
534
535        # Test with a compiled regex.
536        validator = RegexValidator(regex=re.compile(r"^\w+$"))
537        string = MigrationWriter.serialize(validator)[0]
538        self.assertEqual(
539            string,
540            "django.core.validators.RegexValidator(regex=re.compile('^\\\\w+$'))",
541        )
542        self.serialize_round_trip(validator)
543
544        # Test a string regex with flag
545        validator = RegexValidator(r"^[0-9]+$", flags=re.S)
546        string = MigrationWriter.serialize(validator)[0]
547        self.assertEqual(
548            string,
            "django.core.validators.RegexValidator('^[0-9]+$', flags=re.RegexFlag['DOTALL'])",
        )
        self.serialize_round_trip(validator)

        # Test message and code
        validator = RegexValidator("^[-a-zA-Z0-9_]+$", "Invalid", "invalid")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string,
            "django.core.validators.RegexValidator('^[-a-zA-Z0-9_]+$', 'Invalid', 'invalid')",
        )
        self.serialize_round_trip(validator)

        # Test with a subclass.
        validator = EmailValidator(message="hello")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string, "django.core.validators.EmailValidator(message='hello')"
        )
        self.serialize_round_trip(validator)

        validator = deconstructible(path="migrations.test_writer.EmailValidator")(
            EmailValidator
        )(message="hello")
        string = MigrationWriter.serialize(validator)[0]
        self.assertEqual(
            string, "migrations.test_writer.EmailValidator(message='hello')"
        )

        validator = deconstructible(path="custom.EmailValidator")(EmailValidator)(
            message="hello"
        )
        with self.assertRaisesMessage(ImportError, "No module named 'custom'"):
            MigrationWriter.serialize(validator)

        validator = deconstructible(path="django.core.validators.EmailValidator2")(
            EmailValidator
        )(message="hello")
        with self.assertRaisesMessage(
            ValueError,
            "Could not find object EmailValidator2 in django.core.validators.",
        ):
            MigrationWriter.serialize(validator)

    def test_serialize_empty_nonempty_tuple(self):
        """
        Ticket #22679: makemigrations generates invalid code for (an empty
        tuple) default_permissions = ()
        """
        empty_tuple = ()
        one_item_tuple = ("a",)
        many_items_tuple = ("a", "b", "c")
        self.assertSerializedEqual(empty_tuple)
        self.assertSerializedEqual(one_item_tuple)
        self.assertSerializedEqual(many_items_tuple)

    def test_serialize_range(self):
        string, imports = MigrationWriter.serialize(range(1, 5))
        self.assertEqual(string, "range(1, 5)")
        self.assertEqual(imports, set())

    def test_serialize_builtins(self):
        string, imports = MigrationWriter.serialize(range)
        self.assertEqual(string, "range")
        self.assertEqual(imports, set())

    def test_serialize_unbound_method_reference(self):
        """An unbound method used within a class body can be serialized."""
        self.serialize_round_trip(TestModel1.thing)

    def test_serialize_local_function_reference(self):
        """A reference in a local scope can't be serialized."""

        class TestModel2:
            def upload_to(self):
                return "somewhere dynamic"

            thing = models.FileField(upload_to=upload_to)

        with self.assertRaisesMessage(
            ValueError, "Could not find function upload_to in migrations.test_writer"
        ):
            self.serialize_round_trip(TestModel2.thing)

    def test_serialize_managers(self):
        self.assertSerializedEqual(models.Manager())
        self.assertSerializedResultEqual(
            FoodQuerySet.as_manager(),
            (
                "migrations.models.FoodQuerySet.as_manager()",
                {"import migrations.models"},
            ),
        )
        self.assertSerializedEqual(FoodManager("a", "b"))
        self.assertSerializedEqual(FoodManager("x", "y", c=3, d=4))

    def test_serialize_frozensets(self):
        self.assertSerializedEqual(frozenset())
        self.assertSerializedEqual(frozenset("let it go"))

    def test_serialize_set(self):
        self.assertSerializedEqual(set())
        self.assertSerializedResultEqual(set(), ("set()", set()))
        self.assertSerializedEqual({"a"})
        self.assertSerializedResultEqual({"a"}, ("{'a'}", set()))

    def test_serialize_timedelta(self):
        self.assertSerializedEqual(datetime.timedelta())
        self.assertSerializedEqual(datetime.timedelta(minutes=42))

    def test_serialize_functools_partial(self):
        value = functools.partial(datetime.timedelta, 1, seconds=2)
        result = self.serialize_round_trip(value)
        self.assertEqual(result.func, value.func)
        self.assertEqual(result.args, value.args)
        self.assertEqual(result.keywords, value.keywords)

    def test_serialize_functools_partialmethod(self):
        value = functools.partialmethod(datetime.timedelta, 1, seconds=2)
        result = self.serialize_round_trip(value)
        self.assertIsInstance(result, functools.partialmethod)
        self.assertEqual(result.func, value.func)
        self.assertEqual(result.args, value.args)
        self.assertEqual(result.keywords, value.keywords)

    def test_serialize_type_none(self):
        self.assertSerializedEqual(type(None))

    def test_simple_migration(self):
        """
        Tests serializing a simple migration.
        """
        fields = {
            "charfield": models.DateTimeField(default=datetime.datetime.utcnow),
            "datetimefield": models.DateTimeField(default=datetime.datetime.utcnow),
        }

        options = {
            "verbose_name": "My model",
            "verbose_name_plural": "My models",
        }

        migration = type(
            "Migration",
            (migrations.Migration,),
            {
                "operations": [
                    migrations.CreateModel(
                        "MyModel", tuple(fields.items()), options, (models.Model,)
                    ),
                    migrations.CreateModel(
                        "MyModel2", tuple(fields.items()), bases=(models.Model,)
                    ),
                    migrations.CreateModel(
                        name="MyModel3",
                        fields=tuple(fields.items()),
                        options=options,
                        bases=(models.Model,),
                    ),
                    migrations.DeleteModel("MyModel"),
                    migrations.AddField(
                        "OtherModel", "datetimefield", fields["datetimefield"]
                    ),
                ],
                "dependencies": [("testapp", "some_other_one")],
            },
        )
        writer = MigrationWriter(migration)
        output = writer.as_string()
        # We don't test the output formatting - that's too fragile.
        # Just make sure it runs for now, and that things look alright.
        result = self.safe_exec(output)
        self.assertIn("Migration", result)

    def test_migration_path(self):
        test_apps = [
            "migrations.migrations_test_apps.normal",
            "migrations.migrations_test_apps.with_package_model",
            "migrations.migrations_test_apps.without_init_file",
        ]

        base_dir = os.path.dirname(os.path.dirname(__file__))

        for app in test_apps:
            with self.modify_settings(INSTALLED_APPS={"append": app}):
                migration = migrations.Migration("0001_initial", app.split(".")[-1])
                expected_path = os.path.join(
                    base_dir, *(app.split(".") + ["migrations", "0001_initial.py"])
                )
                writer = MigrationWriter(migration)
                self.assertEqual(writer.path, expected_path)

    def test_custom_operation(self):
        migration = type(
            "Migration",
            (migrations.Migration,),
            {
                "operations": [
                    custom_migration_operations.operations.TestOperation(),
                    custom_migration_operations.operations.CreateModel(),
                    migrations.CreateModel("MyModel", (), {}, (models.Model,)),
                    custom_migration_operations.more_operations.TestOperation(),
                ],
                "dependencies": [],
            },
        )
        writer = MigrationWriter(migration)
        output = writer.as_string()
        result = self.safe_exec(output)
        self.assertIn("custom_migration_operations", result)
        self.assertNotEqual(
            result["custom_migration_operations"].operations.TestOperation,
            result["custom_migration_operations"].more_operations.TestOperation,
        )

    def test_sorted_imports(self):
        """
        #24155 - Tests ordering of imports.
        """
        migration = type(
            "Migration",
            (migrations.Migration,),
            {
                "operations": [
                    migrations.AddField(
                        "mymodel",
                        "myfield",
                        models.DateTimeField(
                            default=datetime.datetime(2012, 1, 1, 1, 1, tzinfo=utc),
                        ),
                    ),
                ]
            },
        )
        writer = MigrationWriter(migration)
        output = writer.as_string()
        self.assertIn(
            "import datetime\n"
            "from django.db import migrations, models\n"
            "from django.utils.timezone import utc\n",
            output,
        )

    def test_migration_file_header_comments(self):
        """
        Test comments at top of file.
        """
        migration = type("Migration", (migrations.Migration,), {"operations": []})
        dt = datetime.datetime(2015, 7, 31, 4, 40, 0, 0, tzinfo=utc)
        with mock.patch("django.db.migrations.writer.now", lambda: dt):
            for include_header in (True, False):
                with self.subTest(include_header=include_header):
                    writer = MigrationWriter(migration, include_header)
                    output = writer.as_string()

                    self.assertEqual(
                        include_header,
                        output.startswith(
                            "# Generated by Django %s on 2015-07-31 04:40\n\n"
                            % get_version()
                        ),
                    )
                    if not include_header:
                        # Make sure the output starts with something that's not
                        # a comment or indentation or blank line
                        self.assertRegex(
                            output.splitlines(keepends=True)[0], r"^[^#\s]+"
                        )

    def test_models_import_omitted(self):
        """
        django.db.models shouldn't be imported if unused.
        """
        migration = type(
            "Migration",
            (migrations.Migration,),
            {
                "operations": [
                    migrations.AlterModelOptions(
                        name="model",
                        options={
                            "verbose_name": "model",
                            "verbose_name_plural": "models",
                        },
                    ),
                ]
            },
        )
        writer = MigrationWriter(migration)
        output = writer.as_string()
        self.assertIn("from django.db import migrations\n", output)

    def test_deconstruct_class_arguments(self):
        # Yes, it doesn't make sense to use a class as a default for a
        # CharField. It does make sense for custom fields though, for example
        # an enumfield that takes the enum class as an argument.
        string = MigrationWriter.serialize(
            models.CharField(default=DeconstructibleInstances)
        )[0]
        self.assertEqual(
            string,
            "models.CharField(default=migrations.test_writer.DeconstructibleInstances)",
        )

    def test_register_serializer(self):
        class ComplexSerializer(BaseSerializer):
            def serialize(self):
                return "complex(%r)" % self.value, {}

        MigrationWriter.register_serializer(complex, ComplexSerializer)
        self.assertSerializedEqual(complex(1, 2))
        MigrationWriter.unregister_serializer(complex)
        with self.assertRaisesMessage(ValueError, "Cannot serialize: (1+2j)"):
            self.assertSerializedEqual(complex(1, 2))

    def test_register_non_serializer(self):
        with self.assertRaisesMessage(
            ValueError, "'TestModel1' must inherit from 'BaseSerializer'."
        ):
            MigrationWriter.register_serializer(complex, TestModel1)
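`functools.partial` objects do not define value-based equality, so a partial rebuilt from serialized migration code never compares equal to the original object. That is presumably why `test_serialize_functools_partial` and `test_serialize_functools_partialmethod` compare `func`, `args`, and `keywords` one by one. The stdlib-only sketch below (no Django involved) illustrates this with the same `datetime.timedelta` partial those tests use.

```python
import datetime
import functools

# Two partials wrapping the same callable with the same arguments.
original = functools.partial(datetime.timedelta, 1, seconds=2)
rebuilt = functools.partial(datetime.timedelta, 1, seconds=2)

# partial does not implement __eq__, so comparison falls back to identity.
assert original != rebuilt

# Comparing the stored pieces is what actually verifies the round trip.
assert rebuilt.func is original.func
assert rebuilt.args == original.args == (1,)
assert rebuilt.keywords == original.keywords == {"seconds": 2}

# Both partials still produce the same value when called.
assert original() == rebuilt() == datetime.timedelta(days=1, seconds=2)
```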

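The tuple and set serialization tests guard against Python literal-syntax pitfalls: a one-item tuple needs a trailing comma (the kind of invalid generated code ticket #22679 refers to), and an empty set has no literal form because `{}` is an empty dict. The helpers below (`naive_tuple_literal`, `tuple_literal`, `set_literal`) are hypothetical illustrations of those pitfalls, not `MigrationWriter`'s actual serializers.

```python
import ast


def naive_tuple_literal(items):
    # Hypothetical naive serializer: joins item reprs inside parentheses.
    return "(" + ", ".join(repr(item) for item in items) + ")"


def tuple_literal(items):
    # Hypothetical corrected serializer: a one-item tuple needs a trailing comma.
    body = ", ".join(repr(item) for item in items)
    if len(items) == 1:
        body += ","
    return "(" + body + ")"


def set_literal(items):
    # There is no literal for an empty set; "{}" evaluates to an empty dict.
    if not items:
        return "set()"
    return "{" + ", ".join(repr(item) for item in sorted(items)) + "}"


for value in [(), ("a",), ("a", "b", "c")]:
    assert ast.literal_eval(tuple_literal(value)) == value

# Without the trailing comma, ("a",) comes back as the plain string "a".
assert ast.literal_eval(naive_tuple_literal(("a",))) == "a"

assert type(eval("{}")) is dict
assert eval(set_literal(set())) == set()
assert eval(set_literal({"a"})) == {"a"}
```

Python's built-in `repr()` already emits `('a',)` and `set()` for these cases; the sketch only spells out what the expected strings in the tests protect against.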
tests/model_enums/tests.py

import datetime
import decimal
import ipaddress
import uuid

from django.db import models
from django.test import SimpleTestCase
from django.utils.functional import Promise
from django.utils.translation import gettext_lazy as _


class Suit(models.IntegerChoices):
    DIAMOND = 1, _("Diamond")
    SPADE = 2, _("Spade")
    HEART = 3, _("Heart")
    CLUB = 4, _("Club")


class YearInSchool(models.TextChoices):
    FRESHMAN = "FR", _("Freshman")
    SOPHOMORE = "SO", _("Sophomore")
    JUNIOR = "JR", _("Junior")
    SENIOR = "SR", _("Senior")
    GRADUATE = "GR", _("Graduate")


class Vehicle(models.IntegerChoices):
    CAR = 1, "Carriage"
    TRUCK = 2
    JET_SKI = 3

    __empty__ = _("(Unknown)")


class Gender(models.TextChoices):
    MALE = "M"
    FEMALE = "F"
    NOT_SPECIFIED = "X"

    __empty__ = "(Undeclared)"


class ChoicesTests(SimpleTestCase):
    def test_integerchoices(self):
        self.assertEqual(
            Suit.choices, [(1, "Diamond"), (2, "Spade"), (3, "Heart"), (4, "Club")]
        )
        self.assertEqual(Suit.labels, ["Diamond", "Spade", "Heart", "Club"])
        self.assertEqual(Suit.values, [1, 2, 3, 4])
        self.assertEqual(Suit.names, ["DIAMOND", "SPADE", "HEART", "CLUB"])

        self.assertEqual(repr(Suit.DIAMOND), "<Suit.DIAMOND: 1>")
        self.assertEqual(Suit.DIAMOND.label, "Diamond")
        self.assertEqual(Suit.DIAMOND.value, 1)
        self.assertEqual(Suit["DIAMOND"], Suit.DIAMOND)
        self.assertEqual(Suit(1), Suit.DIAMOND)

        self.assertIsInstance(Suit, type(models.Choices))
        self.assertIsInstance(Suit.DIAMOND, Suit)
        self.assertIsInstance(Suit.DIAMOND.label, Promise)
        self.assertIsInstance(Suit.DIAMOND.value, int)

    def test_integerchoices_auto_label(self):
        self.assertEqual(Vehicle.CAR.label, "Carriage")
        self.assertEqual(Vehicle.TRUCK.label, "Truck")
        self.assertEqual(Vehicle.JET_SKI.label, "Jet Ski")

    def test_integerchoices_empty_label(self):
        self.assertEqual(Vehicle.choices[0], (None, "(Unknown)"))
        self.assertEqual(Vehicle.labels[0], "(Unknown)")
        self.assertIsNone(Vehicle.values[0])
        self.assertEqual(Vehicle.names[0], "__empty__")

    def test_integerchoices_functional_api(self):
        Place = models.IntegerChoices("Place", "FIRST SECOND THIRD")
        self.assertEqual(Place.labels, ["First", "Second", "Third"])
        self.assertEqual(Place.values, [1, 2, 3])
        self.assertEqual(Place.names, ["FIRST", "SECOND", "THIRD"])

    def test_integerchoices_containment(self):
        self.assertIn(Suit.DIAMOND, Suit)
        self.assertIn(1, Suit)
        self.assertNotIn(0, Suit)

    def test_textchoices(self):
        self.assertEqual(
            YearInSchool.choices,
            [
                ("FR", "Freshman"),
                ("SO", "Sophomore"),
                ("JR", "Junior"),
                ("SR", "Senior"),
                ("GR", "Graduate"),
            ],
        )
        self.assertEqual(
            YearInSchool.labels,
            ["Freshman", "Sophomore", "Junior", "Senior", "Graduate"],
        )
        self.assertEqual(YearInSchool.values, ["FR", "SO", "JR", "SR", "GR"])
        self.assertEqual(
            YearInSchool.names,
            ["FRESHMAN", "SOPHOMORE", "JUNIOR", "SENIOR", "GRADUATE"],
        )

        self.assertEqual(repr(YearInSchool.FRESHMAN), "<YearInSchool.FRESHMAN: 'FR'>")
        self.assertEqual(YearInSchool.FRESHMAN.label, "Freshman")
        self.assertEqual(YearInSchool.FRESHMAN.value, "FR")
        self.assertEqual(YearInSchool["FRESHMAN"], YearInSchool.FRESHMAN)
        self.assertEqual(YearInSchool("FR"), YearInSchool.FRESHMAN)

        self.assertIsInstance(YearInSchool, type(models.Choices))
        self.assertIsInstance(YearInSchool.FRESHMAN, YearInSchool)
        self.assertIsInstance(YearInSchool.FRESHMAN.label, Promise)
        self.assertIsInstance(YearInSchool.FRESHMAN.value, str)

    def test_textchoices_auto_label(self):
        self.assertEqual(Gender.MALE.label, "Male")
        self.assertEqual(Gender.FEMALE.label, "Female")
        self.assertEqual(Gender.NOT_SPECIFIED.label, "Not Specified")

    def test_textchoices_empty_label(self):
        self.assertEqual(Gender.choices[0], (None, "(Undeclared)"))
        self.assertEqual(Gender.labels[0], "(Undeclared)")
        self.assertIsNone(Gender.values[0])
        self.assertEqual(Gender.names[0], "__empty__")

    def test_textchoices_functional_api(self):
        Medal = models.TextChoices("Medal", "GOLD SILVER BRONZE")
        self.assertEqual(Medal.labels, ["Gold", "Silver", "Bronze"])
        self.assertEqual(Medal.values, ["GOLD", "SILVER", "BRONZE"])
        self.assertEqual(Medal.names, ["GOLD", "SILVER", "BRONZE"])

    def test_textchoices_containment(self):
        self.assertIn(YearInSchool.FRESHMAN, YearInSchool)
        self.assertIn("FR", YearInSchool)
        self.assertNotIn("XX", YearInSchool)

    def test_textchoices_blank_value(self):
        class BlankStr(models.TextChoices):
            EMPTY = "", "(Empty)"
            ONE = "ONE", "One"

        self.assertEqual(BlankStr.labels, ["(Empty)", "One"])
        self.assertEqual(BlankStr.values, ["", "ONE"])
        self.assertEqual(BlankStr.names, ["EMPTY", "ONE"])

    def test_invalid_definition(self):
        msg = "'str' object cannot be interpreted as an integer"
        with self.assertRaisesMessage(TypeError, msg):

            class InvalidArgumentEnum(models.IntegerChoices):
                # A string is not permitted as the second argument to int().
                ONE = 1, "X", "Invalid"

        msg = "duplicate values found in <enum 'Fruit'>: PINEAPPLE -> APPLE"
        with self.assertRaisesMessage(ValueError, msg):

            class Fruit(models.IntegerChoices):
                APPLE = 1, "Apple"
                PINEAPPLE = 1, "Pineapple"

    def test_str(self):
        for test in [Gender, Suit, YearInSchool, Vehicle]:
            for member in test:
                with self.subTest(member=member):
                    self.assertEqual(str(test[member.name]), str(member.value))


class Separator(bytes, models.Choices):
    FS = b"\x1c", "File Separator"
    GS = b"\x1d", "Group Separator"
    RS = b"\x1e", "Record Separator"
    US = b"\x1f", "Unit Separator"


class Constants(float, models.Choices):
    PI = 3.141592653589793, "π"
    TAU = 6.283185307179586, "τ"


class Set(frozenset, models.Choices):
    A = {1, 2}
    B = {2, 3}
    UNION = A | B
    DIFFERENCE = A - B
    INTERSECTION = A & B


class MoonLandings(datetime.date, models.Choices):
    APOLLO_11 = 1969, 7, 20, "Apollo 11 (Eagle)"
    APOLLO_12 = 1969, 11, 19, "Apollo 12 (Intrepid)"
    APOLLO_14 = 1971, 2, 5, "Apollo 14 (Antares)"
    APOLLO_15 = 1971, 7, 30, "Apollo 15 (Falcon)"
    APOLLO_16 = 1972, 4, 21, "Apollo 16 (Orion)"
    APOLLO_17 = 1972, 12, 11, "Apollo 17 (Challenger)"


class DateAndTime(datetime.datetime, models.Choices):
    A = 2010, 10, 10, 10, 10, 10
    B = 2011, 11, 11, 11, 11, 11
    C = 2012, 12, 12, 12, 12, 12


class MealTimes(datetime.time, models.Choices):
    BREAKFAST = 7, 0
    LUNCH = 13, 0
    DINNER = 18, 30


class Frequency(datetime.timedelta, models.Choices):
    WEEK = 0, 0, 0, 0, 0, 0, 1, "Week"
    DAY = 1, "Day"
    HOUR = 0, 0, 0, 0, 0, 1, "Hour"
    MINUTE = 0, 0, 0, 0, 1, "Minute"
    SECOND = 0, 1, "Second"


class Number(decimal.Decimal, models.Choices):
    E = 2.718281828459045, "e"
    PI = "3.141592653589793", "π"
    TAU = decimal.Decimal("6.283185307179586"), "τ"


class IPv4Address(ipaddress.IPv4Address, models.Choices):
    LOCALHOST = "127.0.0.1", "Localhost"
    GATEWAY = "192.168.0.1", "Gateway"
    BROADCAST = "192.168.0.255", "Broadcast"


class IPv6Address(ipaddress.IPv6Address, models.Choices):
    LOCALHOST = "::1", "Localhost"
    UNSPECIFIED = "::", "Unspecified"


class IPv4Network(ipaddress.IPv4Network, models.Choices):
    LOOPBACK = "127.0.0.0/8", "Loopback"
    LINK_LOCAL = "169.254.0.0/16", "Link-Local"
    PRIVATE_USE_A = "10.0.0.0/8", "Private-Use (Class A)"


class IPv6Network(ipaddress.IPv6Network, models.Choices):
    LOOPBACK = "::1/128", "Loopback"
    UNSPECIFIED = "::/128", "Unspecified"
    UNIQUE_LOCAL = "fc00::/7", "Unique-Local"
    LINK_LOCAL_UNICAST = "fe80::/10", "Link-Local Unicast"


class CustomChoicesTests(SimpleTestCase):
    def test_labels_valid(self):
        enums = (
            Separator,
            Constants,
            Set,
            MoonLandings,
            DateAndTime,
            MealTimes,
            Frequency,
            Number,
            IPv4Address,
            IPv6Address,
            IPv4Network,
            IPv6Network,
        )
        for choice_enum in enums:
            with self.subTest(choice_enum.__name__):
                self.assertNotIn(None, choice_enum.labels)

    def test_bool_unsupported(self):
        msg = "type 'bool' is not an acceptable base type"
        with self.assertRaisesMessage(TypeError, msg):

            class Boolean(bool, models.Choices):
                pass

    def test_timezone_unsupported(self):
        msg = "type 'datetime.timezone' is not an acceptable base type"
        with self.assertRaisesMessage(TypeError, msg):

            class Timezone(datetime.timezone, models.Choices):
                pass

    def test_uuid_unsupported(self):
        msg = "UUID objects are immutable"
        with self.assertRaisesMessage(TypeError, msg):

            class Identifier(uuid.UUID, models.Choices):
                A = "972ce4eb-a95f-4a56-9339-68c208a76f18"
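The enumeration-types description says a trailing string in a member's value tuple becomes the label, and that otherwise the label is generated from the member name. The hypothetical `split_label()` helper below sketches that rule using only the standard library, so it can be checked against members defined in these tests. It is not Django's implementation: it treats only a plain `str` as a label, whereas Django also accepts lazy translation strings.

```python
import datetime


def split_label(name, value):
    """Hypothetical helper: return (constructor_args, label) for one member.

    A trailing str in a tuple of two or more items is the label; otherwise
    the label is derived from the member name. Plain str stands in for
    Django's lazy translation strings in this sketch.
    """
    if (
        isinstance(value, (list, tuple))
        and len(value) > 1
        and isinstance(value[-1], str)
    ):
        *args, label = value
        return tuple(args), label
    return value, name.replace("_", " ").title()


# Explicit labels, as in Suit and Vehicle.CAR.
assert split_label("DIAMOND", (1, "Diamond")) == ((1,), "Diamond")
assert split_label("CAR", (1, "Carriage")) == ((1,), "Carriage")

# No tuple: the label comes from the member name (Vehicle.JET_SKI, Gender).
assert split_label("JET_SKI", 3) == (3, "Jet Ski")
assert split_label("NOT_SPECIFIED", "X") == ("X", "Not Specified")

# Multi-argument values feed the concrete type's constructor.
args, label = split_label("APOLLO_11", (1969, 7, 20, "Apollo 11 (Eagle)"))
assert datetime.date(*args) == datetime.date(1969, 7, 20)
assert label == "Apollo 11 (Eagle)"

# No trailing string, so every item is a constructor argument (DateAndTime).
args, label = split_label("A", (2010, 10, 10, 10, 10, 10))
assert datetime.datetime(*args) == datetime.datetime(2010, 10, 10, 10, 10, 10)
assert label == "A"
```

The same rule explains `test_invalid_definition`: for `ONE = 1, "X", "Invalid"` only `"Invalid"` is taken as the label, leaving `int(1, "X")`, which raises `TypeError`.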

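The `Frequency` members rely on the positional order of the `datetime.timedelta` constructor: days, seconds, microseconds, milliseconds, minutes, hours, weeks. That is why `WEEK` needs six leading zeros while `SECOND` needs only one. A quick stdlib check of each member's constructor arguments, with the trailing labels removed:

```python
import datetime

# Positional order: days, seconds, microseconds, milliseconds, minutes, hours, weeks.
members = {
    "WEEK": (0, 0, 0, 0, 0, 0, 1),
    "DAY": (1,),
    "HOUR": (0, 0, 0, 0, 0, 1),
    "MINUTE": (0, 0, 0, 0, 1),
    "SECOND": (0, 1),
}
expected = {
    "WEEK": datetime.timedelta(weeks=1),
    "DAY": datetime.timedelta(days=1),
    "HOUR": datetime.timedelta(hours=1),
    "MINUTE": datetime.timedelta(minutes=1),
    "SECOND": datetime.timedelta(seconds=1),
}
for name, args in members.items():
    assert datetime.timedelta(*args) == expected[name], name

# Each member's duration in seconds, for a quick sanity check.
total_seconds = {
    name: datetime.timedelta(*args).total_seconds() for name, args in members.items()
}
```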
tests/model_fields/test_charfield.py

from unittest import skipIf

from django.core.exceptions import ValidationError
from django.db import connection, models
from django.test import SimpleTestCase, TestCase

from .models import Post


class TestCharField(TestCase):
    def test_max_length_passed_to_formfield(self):
        """
        CharField passes its max_length attribute to form fields created using
        the formfield() method.
        """
        cf1 = models.CharField()
        cf2 = models.CharField(max_length=1234)
        self.assertIsNone(cf1.formfield().max_length)
        self.assertEqual(1234, cf2.formfield().max_length)

    def test_lookup_integer_in_charfield(self):
        self.assertEqual(Post.objects.filter(title=9).count(), 0)

    @skipIf(
        connection.vendor == "mysql",
        "Running on MySQL requires utf8mb4 encoding (#18392)",
    )
    def test_emoji(self):
        p = Post.objects.create(title="Smile 😀", body="Whatever.")
        p.refresh_from_db()
        self.assertEqual(p.title, "Smile 😀")

    def test_assignment_from_choice_enum(self):
        class Event(models.TextChoices):
            C = "Carnival!"
            F = "Festival!"

        p1 = Post.objects.create(title=Event.C, body=Event.F)
        p1.refresh_from_db()
        self.assertEqual(p1.title, "Carnival!")
        self.assertEqual(p1.body, "Festival!")
        self.assertEqual(p1.title, Event.C)
        self.assertEqual(p1.body, Event.F)
        p2 = Post.objects.get(title="Carnival!")
        self.assertEqual(p1, p2)
        self.assertEqual(p2.title, Event.C)


class ValidationTests(SimpleTestCase):
    class Choices(models.TextChoices):
        C = "c", "C"

    def test_charfield_raises_error_on_empty_string(self):
        f = models.CharField()
        with self.assertRaises(ValidationError):
            f.clean("", None)

    def test_charfield_cleans_empty_string_when_blank_true(self):
        f = models.CharField(blank=True)
        self.assertEqual("", f.clean("", None))

    def test_charfield_with_choices_cleans_valid_choice(self):
        f = models.CharField(max_length=1, choices=[("a", "A"), ("b", "B")])
        self.assertEqual("a", f.clean("a", None))

    def test_charfield_with_choices_raises_error_on_invalid_choice(self):
        f = models.CharField(choices=[("a", "A"), ("b", "B")])
        with self.assertRaises(ValidationError):
            f.clean("not a", None)

    def test_enum_choices_cleans_valid_string(self):
        f = models.CharField(choices=self.Choices.choices, max_length=1)
        self.assertEqual(f.clean("c", None), "c")

    def test_enum_choices_invalid_input(self):
        f = models.CharField(choices=self.Choices.choices, max_length=1)
        with self.assertRaises(ValidationError):
            f.clean("a", None)

    def test_charfield_raises_error_on_empty_input(self):
        f = models.CharField(null=False)
        with self.assertRaises(ValidationError):
            f.clean(None, None)
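`test_assignment_from_choice_enum` works because a `TextChoices` member is itself a `str`: it can be saved directly to a `CharField` and compares equal to the plain string read back from the database. The sketch below approximates that behavior with a stdlib `str`-mixin `enum.Enum`. The `Event` class here is a hypothetical stand-in, and its `__str__` override mimics what `test_str` expects of Django's `Choices` members rather than reproducing Django's code.

```python
import enum


class Event(str, enum.Enum):
    """Hypothetical stand-in for a TextChoices subclass, stdlib only."""

    C = "Carnival!"
    F = "Festival!"

    def __str__(self):
        # Report the stored value, as test_str expects of Choices members.
        return str(self.value)


# A str-mixin member is a str, so it compares equal to the plain value...
assert isinstance(Event.C, str)
assert Event.C == "Carnival!"
# ...and a plain string read back from storage compares equal to the member.
assert "Carnival!" == Event.C
assert Event("Carnival!") is Event.C
assert str(Event.F) == "Festival!"

# Looking up an unknown value fails, loosely like an invalid choice.
try:
    Event("Parade!")
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError for an unknown value")
```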