Created
June 17, 2020 17:39
-
-
Save pmdarrow/97e36ae996296a84906fcacb3d44740c to your computer and use it in GitHub Desktop.
Simple marshmallow enum field with support for apispec
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loosely based on https://github.com/h/marshmallow_enum
from marshmallow.fields import Field | |
class EnumField(Field):
    """Marshmallow field that (de)serializes enum members by their value.

    On construction the field inspects the values of ``enum_type`` and
    records a ``type`` and an ``enum`` entry in ``self.metadata`` so that
    schema generators reading field metadata (e.g. apispec) can describe
    the allowed values.
    """

    default_error_messages = {
        'invalid': 'Invalid enum value {input}',
    }

    def __init__(self, enum_type, *args, **kwargs):
        """Create the field.

        :param enum_type: The enum class whose members this field handles.
        :param args: Positional arguments forwarded to ``Field.__init__``.
        :param kwargs: Keyword arguments forwarded to ``Field.__init__``.
        """
        self.enum = enum_type
        super(EnumField, self).__init__(*args, **kwargs)

        values = [e.value for e in self.enum]
        non_null = [v for v in values if v is not None]

        # Detect type of enum and make it available to apispec.
        # ``bool`` is a subclass of ``int``, so it has to be tested first;
        # otherwise boolean enums would be reported as 'integer'.
        # Nothing is recorded when there are no non-None values, because
        # ``all()`` over an empty sequence is vacuously true.
        if non_null:
            if all(isinstance(v, bool) for v in non_null):
                self.metadata['type'] = 'boolean'
            elif all(isinstance(v, int) for v in non_null):
                self.metadata['type'] = 'integer'
            elif all(isinstance(v, (float, int)) for v in non_null):
                self.metadata['type'] = 'number'
            elif all(isinstance(v, str) for v in non_null):
                self.metadata['type'] = 'string'

        # Ensure all enum values are made available to apispec.
        # None cannot be ordered against other values on Python 3, so only
        # the non-None values are sorted and None (if present) goes last.
        try:
            enum_values = sorted(non_null)
        except TypeError:
            # Values of mixed, non-comparable types: keep definition order.
            enum_values = list(non_null)
        if len(non_null) != len(values):
            enum_values.append(None)
        self.metadata['enum'] = enum_values

    def _serialize(self, value, attr, obj, **kwargs):
        """Return the underlying value of an enum member.

        :param value: The enum member to serialize, or ``None``.
        :param attr: The attribute name being serialized (unused).
        :param obj: The object the value was pulled from (unused).
        :param kwargs: Extra keyword arguments, accepted for parity with
            ``_deserialize`` and ignored.
        :return: ``value.value``, or ``None`` when ``value`` is ``None``.
        """
        if value is None:
            return None
        return value.value

    def _deserialize(self, value, attr, data, **kwargs):
        """Convert a raw value into the matching enum member.

        :param value: The raw value to look up, or ``None``.
        :param attr: The attribute name being deserialized (unused).
        :param data: The raw input data (unused).
        :param kwargs: Extra keyword arguments (ignored).
        :return: The enum member for ``value``, or ``None`` when ``value``
            is ``None``.
        :raises: Whatever ``fail('invalid', ...)`` raises when ``value`` is
            not a valid value of the enum.
        """
        if value is None:
            return None
        try:
            return self.enum(value)
        except ValueError:
            self.fail('invalid', input=value, value=value)

    def fail(self, key, **kwargs):
        """Raise the error registered under ``key``.

        A ``values`` entry (comma-separated list of every enum value) is
        added to the formatting arguments so custom error messages may
        reference ``{values}``.

        :param key: Key into the field's error messages.
        :param kwargs: Formatting arguments for the error message.
        """
        kwargs['values'] = ', '.join(str(mem.value) for mem in self.enum)
        # Prefer ``make_error`` when the base class provides it (newer
        # marshmallow deprecates ``Field.fail``); otherwise fall back to
        # the base ``fail`` implementation.
        make_error = getattr(self, 'make_error', None)
        if make_error is not None:
            raise make_error(key, **kwargs)
        super(EnumField, self).fail(key, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment