@@ -714,6 +714,13 @@ def test_basics(self):
714
714
Literal ["x" , "y" , "z" ]
715
715
Literal [None ]
716
716
717
+ def test_enum (self ):
718
+ import enum
719
+ class My (enum .Enum ):
720
+ A = 'A'
721
+
722
+ self .assertEqual (Literal [My .A ].__args__ , (My .A ,))
723
+
717
724
def test_illegal_parameters_do_not_raise_runtime_errors (self ):
718
725
# Type checkers should reject these types, but we do not
719
726
# raise errors at runtime to maintain maximum flexibility
@@ -794,6 +801,64 @@ def test_args(self):
794
801
# Mutable arguments will not be deduplicated
795
802
self .assertEqual (Literal [[], []].__args__ , ([], []))
796
803
804
+ def test_union_of_literals (self ):
805
+ self .assertEqual (Union [Literal [1 ], Literal [2 ]].__args__ ,
806
+ (Literal [1 ], Literal [2 ]))
807
+ self .assertEqual (Union [Literal [1 ], Literal [1 ]],
808
+ Literal [1 ])
809
+
810
+ self .assertEqual (Union [Literal [False ], Literal [0 ]].__args__ ,
811
+ (Literal [False ], Literal [0 ]))
812
+ self .assertEqual (Union [Literal [True ], Literal [1 ]].__args__ ,
813
+ (Literal [True ], Literal [1 ]))
814
+
815
+ import enum
816
+ class Ints (enum .IntEnum ):
817
+ A = 0
818
+ B = 1
819
+
820
+ self .assertEqual (Union [Literal [Ints .A ], Literal [Ints .B ]].__args__ ,
821
+ (Literal [Ints .A ], Literal [Ints .B ]))
822
+
823
+ self .assertEqual (Union [Literal [Ints .A ], Literal [Ints .A ]],
824
+ Literal [Ints .A ])
825
+ self .assertEqual (Union [Literal [Ints .B ], Literal [Ints .B ]],
826
+ Literal [Ints .B ])
827
+
828
+ self .assertEqual (Union [Literal [0 ], Literal [Ints .A ], Literal [False ]].__args__ ,
829
+ (Literal [0 ], Literal [Ints .A ], Literal [False ]))
830
+ self .assertEqual (Union [Literal [1 ], Literal [Ints .B ], Literal [True ]].__args__ ,
831
+ (Literal [1 ], Literal [Ints .B ], Literal [True ]))
832
+
833
+ @skipUnless (TYPING_3_10_0 , "Python 3.10+ required" )
834
+ def test_or_type_operator_with_Literal (self ):
835
+ self .assertEqual ((Literal [1 ] | Literal [2 ]).__args__ ,
836
+ (Literal [1 ], Literal [2 ]))
837
+
838
+ self .assertEqual ((Literal [0 ] | Literal [False ]).__args__ ,
839
+ (Literal [0 ], Literal [False ]))
840
+ self .assertEqual ((Literal [1 ] | Literal [True ]).__args__ ,
841
+ (Literal [1 ], Literal [True ]))
842
+
843
+ self .assertEqual (Literal [1 ] | Literal [1 ], Literal [1 ])
844
+ self .assertEqual (Literal ['a' ] | Literal ['a' ], Literal ['a' ])
845
+
846
+ import enum
847
+ class Ints (enum .IntEnum ):
848
+ A = 0
849
+ B = 1
850
+
851
+ self .assertEqual (Literal [Ints .A ] | Literal [Ints .A ], Literal [Ints .A ])
852
+ self .assertEqual (Literal [Ints .B ] | Literal [Ints .B ], Literal [Ints .B ])
853
+
854
+ self .assertEqual ((Literal [Ints .B ] | Literal [Ints .A ]).__args__ ,
855
+ (Literal [Ints .B ], Literal [Ints .A ]))
856
+
857
+ self .assertEqual ((Literal [0 ] | Literal [Ints .A ]).__args__ ,
858
+ (Literal [0 ], Literal [Ints .A ]))
859
+ self .assertEqual ((Literal [1 ] | Literal [Ints .B ]).__args__ ,
860
+ (Literal [1 ], Literal [Ints .B ]))
861
+
797
862
def test_flatten (self ):
798
863
l1 = Literal [Literal [1 ], Literal [2 ], Literal [3 ]]
799
864
l2 = Literal [Literal [1 , 2 ], 3 ]
@@ -802,6 +867,20 @@ def test_flatten(self):
802
867
self .assertEqual (lit , Literal [1 , 2 , 3 ])
803
868
self .assertEqual (lit .__args__ , (1 , 2 , 3 ))
804
869
870
+ def test_does_not_flatten_enum (self ):
871
+ import enum
872
+ class Ints (enum .IntEnum ):
873
+ A = 1
874
+ B = 2
875
+
876
+ literal = Literal [
877
+ Literal [Ints .A ],
878
+ Literal [Ints .B ],
879
+ Literal [1 ],
880
+ Literal [2 ],
881
+ ]
882
+ self .assertEqual (literal .__args__ , (Ints .A , Ints .B , 1 , 2 ))
883
+
805
884
def test_caching_of_Literal_respects_type (self ):
806
885
self .assertIs (type (Literal [1 ].__args__ [0 ]), int )
807
886
self .assertIs (type (Literal [True ].__args__ [0 ]), bool )
0 commit comments