34 lines
1.3 KiB
Python
34 lines
1.3 KiB
Python
from functools import partial
|
|
from itertools import groupby
|
|
from operator import attrgetter
|
|
|
|
from django.forms.models import ModelChoiceIterator, ModelChoiceField
|
|
|
|
|
|
class GroupedModelChoiceIterator(ModelChoiceIterator):
    """Choice iterator that yields model choices as ``(group, [choices])`` pairs.

    Rows from the field's queryset are bucketed with :func:`itertools.groupby`
    using the ``groupby`` key function supplied at construction time. Only
    *consecutive* rows with equal keys are merged. If the queryset is not
    ordered by the grouping key, the same group label can therefore appear
    more than once.
    """

    def __init__(self, field, groupby):
        # ``groupby``: one-argument callable mapping a model instance to its
        # group key. The name shadows ``itertools.groupby`` only inside this
        # method; it is kept because callers pass it by keyword.
        self.groupby = groupby
        super().__init__(field)

    def __iter__(self):
        # Optional leading blank choice, as configured on the field.
        if self.field.empty_label is not None:  # pragma: no cover
            yield ("", self.field.empty_label)

        rows = self.queryset
        # Can't use iterator() when queryset uses prefetch_related()
        has_prefetch = bool(rows._prefetch_related_lookups)
        if not has_prefetch:
            rows = rows.iterator()

        key_func = self.groupby
        for label, members in groupby(rows, key_func):
            # Each group becomes one (label, list-of-choices) entry.
            grouped_choices = [self.choice(member) for member in members]
            yield (label, grouped_choices)
|
|
|
|
|
|
class GroupedModelChoiceField(ModelChoiceField):
    """``ModelChoiceField`` whose choices are emitted in groups.

    ``choices_groupby`` (keyword-only, required) decides each object's group.
    It may be given in one of two forms:

    * an attribute name (``str``), turned into a key function with
      :func:`operator.attrgetter`;
    * a callable that takes one model instance and returns its group key.

    All other positional/keyword arguments are forwarded unchanged to
    ``ModelChoiceField``.

    Raises:
        TypeError: if ``choices_groupby`` is neither a ``str`` nor callable.
    """

    def __init__(self, *args, choices_groupby, **kwargs):
        if isinstance(choices_groupby, str):
            group_key = attrgetter(choices_groupby)
        elif callable(choices_groupby):
            group_key = choices_groupby
        else:  # pragma: no cover
            raise TypeError(
                'choices_groupby must either be a str or a callable accepting a single argument')

        # NOTE(review): assigned before super().__init__(), presumably because
        # the parent constructor already builds choices through
        # ``self.iterator``. Confirm before reordering these two statements.
        self.iterator = partial(GroupedModelChoiceIterator, groupby=group_key)
        super().__init__(*args, **kwargs)
|