118
118
import time
119
119
import math
120
120
import datetime
121
+ import functools
121
122
122
123
import warnings
123
124
@@ -732,20 +733,105 @@ def __call__(self, x, pos=None):
732
733
733
734
734
735
class rrulewrapper (object ):
736
+ def __init__ (self , freq , tzinfo = None , ** kwargs ):
737
+ kwargs ['freq' ] = freq
738
+ self ._base_tzinfo = tzinfo
735
739
736
- def __init__ (self , freq , ** kwargs ):
737
- self ._construct = kwargs .copy ()
738
- self ._construct ["freq" ] = freq
739
- self ._rrule = rrule (** self ._construct )
740
+ self ._update_rrule (** kwargs )
740
741
741
742
def set (self , ** kwargs ):
742
743
self ._construct .update (kwargs )
744
+
745
+ self ._update_rrule (** self ._construct )
746
+
747
+ def _update_rrule (self , ** kwargs ):
748
+ tzinfo = self ._base_tzinfo
749
+
750
+ # rrule does not play nicely with time zones - especially pytz time
751
+ # zones, it's best to use naive zones and attach timezones once the
752
+ # datetimes are returned
753
+ if 'dtstart' in kwargs :
754
+ dtstart = kwargs ['dtstart' ]
755
+ if dtstart .tzinfo is not None :
756
+ if tzinfo is None :
757
+ tzinfo = dtstart .tzinfo
758
+ else :
759
+ dtstart = dtstart .astimezone (tzinfo )
760
+
761
+ kwargs ['dtstart' ] = dtstart .replace (tzinfo = None )
762
+
763
+ if 'until' in kwargs :
764
+ until = kwargs ['until' ]
765
+ if until .tzinfo is not None :
766
+ if tzinfo is not None :
767
+ until = until .astimezone (tzinfo )
768
+ else :
769
+ raise ValueError ('until cannot be aware if dtstart '
770
+ 'is naive and tzinfo is None' )
771
+
772
+ kwargs ['until' ] = until .replace (tzinfo = None )
773
+
774
+ self ._construct = kwargs .copy ()
775
+ self ._tzinfo = tzinfo
743
776
self ._rrule = rrule (** self ._construct )
744
777
778
+ def _attach_tzinfo (self , dt , tzinfo ):
779
+ # pytz zones are attached by "localizing" the datetime
780
+ if hasattr (tzinfo , 'localize' ):
781
+ return tzinfo .localize (dt , is_dst = True )
782
+
783
+ return dt .replace (tzinfo = tzinfo )
784
+
785
+ def _aware_return_wrapper (self , f , returns_list = False ):
786
+ """Decorator function that allows rrule methods to handle tzinfo."""
787
+ # This is only necessary if we're actually attaching a tzinfo
788
+ if self ._tzinfo is None :
789
+ return f
790
+
791
+ # All datetime arguments must be naive. If they are not naive, they are
792
+ # converted to the _tzinfo zone before dropping the zone.
793
+ def normalize_arg (arg ):
794
+ if isinstance (arg , datetime .datetime ) and arg .tzinfo is not None :
795
+ if arg .tzinfo is not self ._tzinfo :
796
+ arg = arg .astimezone (self ._tzinfo )
797
+
798
+ return arg .replace (tzinfo = None )
799
+
800
+ return arg
801
+
802
+ def normalize_args (args , kwargs ):
803
+ args = tuple (normalize_arg (arg ) for arg in args )
804
+ kwargs = {kw : normalize_arg (arg ) for kw , arg in kwargs .items ()}
805
+
806
+ return args , kwargs
807
+
808
+ # There are two kinds of functions we care about - ones that return
809
+ # dates and ones that return lists of dates.
810
+ if not returns_list :
811
+ def inner_func (* args , ** kwargs ):
812
+ args , kwargs = normalize_args (args , kwargs )
813
+ dt = f (* args , ** kwargs )
814
+ return self ._attach_tzinfo (dt , self ._tzinfo )
815
+ else :
816
+ def inner_func (* args , ** kwargs ):
817
+ args , kwargs = normalize_args (args , kwargs )
818
+ dts = f (* args , ** kwargs )
819
+ return [self ._attach_tzinfo (dt , self ._tzinfo ) for dt in dts ]
820
+
821
+ return functools .wraps (f )(inner_func )
822
+
745
823
def __getattr__ (self , name ):
746
824
if name in self .__dict__ :
747
825
return self .__dict__ [name ]
748
- return getattr (self ._rrule , name )
826
+
827
+ f = getattr (self ._rrule , name )
828
+
829
+ if name in {'after' , 'before' }:
830
+ return self ._aware_return_wrapper (f )
831
+ elif name in {'xafter' , 'xbefore' , 'between' }:
832
+ return self ._aware_return_wrapper (f , returns_list = True )
833
+ else :
834
+ return f
749
835
750
836
def __setstate__ (self , state ):
751
837
self .__dict__ .update (state )
@@ -1226,7 +1312,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None):
1226
1312
bymonth = [x .item () for x in bymonth .astype (int )]
1227
1313
1228
1314
rule = rrulewrapper (MONTHLY , bymonth = bymonth , bymonthday = bymonthday ,
1229
- interval = interval , ** self .hms0d )
1315
+ interval = interval , ** self .hms0d )
1230
1316
RRuleLocator .__init__ (self , rule , tz )
1231
1317
1232
1318
0 commit comments