22
22
from __future__ import division
23
23
from __future__ import print_function
24
24
25
+ import collections
25
26
import csv
26
27
import io
27
28
import string
@@ -76,6 +77,24 @@ def __call__(cls, *args, **kwargs):
76
77
return type .__call__ (cls , * args )
77
78
78
79
80
+ # NOTE about Genericity and Metaclass of ArgumentParser.
81
+ # (1) In the .py source (this file)
82
+ # - is not declared as Generic
83
+ # - has _ArgumentParserCache as a metaclass
84
+ # (2) In the .pyi source (type stub)
85
+ # - is declared as Generic
86
+ # - doesn't have a metaclass
87
+ # The reason we need this is due to Generic having a different metaclass
88
+ # (for python versions <= 3.7) and a class can have only one metaclass.
89
+ #
90
+ # * Lack of metaclass in .pyi is not a deal breaker, since the metaclass
91
+ # doesn't affect any type information. Also type checkers can check the type
92
+ # parameters.
93
+ # * However, not declaring ArgumentParser as Generic in the source affects
94
+ # runtime annotation processing. In particular this means, subclasses should
95
+ # inherit from `ArgumentParser` and not `ArgumentParser[SomeType]`.
96
+ # The corresponding DEFINE_someType method (the public API) can be annotated
97
+ # to return FlagHolder[SomeType].
79
98
class ArgumentParser (six .with_metaclass (_ArgumentParserCache , object )):
80
99
"""Base class used to parse and convert arguments.
81
100
@@ -354,11 +373,13 @@ def flag_type(self):
354
373
class EnumClassParser (ArgumentParser ):
355
374
"""Parser of an Enum class member."""
356
375
357
- def __init__ (self , enum_class ):
376
+ def __init__ (self , enum_class , case_sensitive = True ):
358
377
"""Initializes EnumParser.
359
378
360
379
Args:
361
380
enum_class: class, the Enum class with all possible flag values.
381
+ case_sensitive: bool, whether or not the enum is to be case-sensitive. If
382
+ False, all member names must be unique when case is ignored.
362
383
363
384
Raises:
364
385
TypeError: When enum_class is not a subclass of Enum.
@@ -373,9 +394,30 @@ def __init__(self, enum_class):
373
394
if not enum_class .__members__ :
374
395
raise ValueError ('enum_class cannot be empty, but "{}" is empty.'
375
396
.format (enum_class ))
397
+ if not case_sensitive :
398
+ members = collections .Counter (
399
+ name .lower () for name in enum_class .__members__ )
400
+ duplicate_keys = {
401
+ member for member , count in members .items () if count > 1
402
+ }
403
+ if duplicate_keys :
404
+ raise ValueError (
405
+ 'Duplicate enum values for {} using case_sensitive=False' .format (
406
+ duplicate_keys ))
376
407
377
408
super (EnumClassParser , self ).__init__ ()
378
409
self .enum_class = enum_class
410
+ self ._case_sensitive = case_sensitive
411
+ if case_sensitive :
412
+ self ._member_names = tuple (enum_class .__members__ )
413
+ else :
414
+ self ._member_names = tuple (
415
+ name .lower () for name in enum_class .__members__ )
416
+
417
+ @property
418
+ def member_names (self ):
419
+ """The accepted enum names, in lowercase if not case sensitive."""
420
+ return self ._member_names
379
421
380
422
def parse (self , argument ):
381
423
"""Determines validity of argument and returns the correct element of enum.
@@ -391,11 +433,19 @@ def parse(self, argument):
391
433
"""
392
434
if isinstance (argument , self .enum_class ):
393
435
return argument
394
- if argument not in self .enum_class .__members__ :
395
- raise ValueError ('value should be one of <%s>' %
396
- '|' .join (self .enum_class .__members__ .keys ()))
436
+ elif not isinstance (argument , six .string_types ):
437
+ raise ValueError (
438
+ '{} is not an enum member or a name of a member in {}' .format (
439
+ argument , self .enum_class ))
440
+ key = EnumParser (
441
+ self ._member_names , case_sensitive = self ._case_sensitive ).parse (argument )
442
+ if self ._case_sensitive :
443
+ return self .enum_class [key ]
397
444
else :
398
- return self .enum_class [argument ]
445
+ # If EnumParser.parse() return a value, we're guaranteed to find it
446
+ # as a member of the class
447
+ return next (value for name , value in self .enum_class .__members__ .items ()
448
+ if name .lower () == key .lower ())
399
449
400
450
def flag_type (self ):
401
451
"""See base class."""
@@ -413,13 +463,30 @@ def serialize(self, value):
413
463
414
464
415
465
class EnumClassListSerializer (ListSerializer ):
466
+ """A serializer for MultiEnumClass flags.
467
+
468
+ This serializer simply joins the output of `EnumClassSerializer` using a
469
+ provided seperator.
470
+ """
471
+
472
+ def __init__ (self , list_sep , ** kwargs ):
473
+ """Initializes EnumClassListSerializer.
474
+
475
+ Args:
476
+ list_sep: String to be used as a separator when serializing
477
+ **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
478
+ individual values.
479
+ """
480
+ super (EnumClassListSerializer , self ).__init__ (list_sep )
481
+ self ._element_serializer = EnumClassSerializer (** kwargs )
416
482
417
483
def serialize (self , value ):
418
484
"""See base class."""
419
485
if isinstance (value , list ):
420
- return self .list_sep .join (_helpers .str_or_unicode (x .name ) for x in value )
486
+ return self .list_sep .join (
487
+ self ._element_serializer .serialize (x ) for x in value )
421
488
else :
422
- return _helpers . str_or_unicode (value . name )
489
+ return self . _element_serializer . serialize (value )
423
490
424
491
425
492
class CsvListSerializer (ArgumentSerializer ):
@@ -448,9 +515,18 @@ def serialize(self, value):
448
515
class EnumClassSerializer (ArgumentSerializer ):
449
516
"""Class for generating string representations of an enum class flag value."""
450
517
518
+ def __init__ (self , lowercase ):
519
+ """Initializes EnumClassSerializer.
520
+
521
+ Args:
522
+ lowercase: If True, enum member names are lowercased during serialization.
523
+ """
524
+ self ._lowercase = lowercase
525
+
451
526
def serialize (self , value ):
452
527
"""Returns a serialized string of the Enum class value."""
453
- return _helpers .str_or_unicode (value .name )
528
+ as_string = _helpers .str_or_unicode (value .name )
529
+ return as_string .lower () if self ._lowercase else as_string
454
530
455
531
456
532
class BaseListParser (ArgumentParser ):
0 commit comments