125
125
import time
126
126
import math
127
127
import datetime
128
+ import functools
128
129
129
130
import warnings
130
131
@@ -806,20 +807,105 @@ def __call__(self, x, pos=None):
806
807
807
808
808
809
class rrulewrapper (object ):
810
+ def __init__ (self , freq , tzinfo = None , ** kwargs ):
811
+ kwargs ['freq' ] = freq
812
+ self ._base_tzinfo = tzinfo
809
813
810
- def __init__ (self , freq , ** kwargs ):
811
- self ._construct = kwargs .copy ()
812
- self ._construct ["freq" ] = freq
813
- self ._rrule = rrule (** self ._construct )
814
+ self ._update_rrule (** kwargs )
814
815
815
816
def set (self , ** kwargs ):
816
817
self ._construct .update (kwargs )
818
+
819
+ self ._update_rrule (** self ._construct )
820
+
821
+ def _update_rrule (self , ** kwargs ):
822
+ tzinfo = self ._base_tzinfo
823
+
824
+ # rrule does not play nicely with time zones - especially pytz time
825
+ # zones, it's best to use naive zones and attach timezones once the
826
+ # datetimes are returned
827
+ if 'dtstart' in kwargs :
828
+ dtstart = kwargs ['dtstart' ]
829
+ if dtstart .tzinfo is not None :
830
+ if tzinfo is None :
831
+ tzinfo = dtstart .tzinfo
832
+ else :
833
+ dtstart = dtstart .astimezone (tzinfo )
834
+
835
+ kwargs ['dtstart' ] = dtstart .replace (tzinfo = None )
836
+
837
+ if 'until' in kwargs :
838
+ until = kwargs ['until' ]
839
+ if until .tzinfo is not None :
840
+ if tzinfo is not None :
841
+ until = until .astimezone (tzinfo )
842
+ else :
843
+ raise ValueError ('until cannot be aware if dtstart '
844
+ 'is naive and tzinfo is None' )
845
+
846
+ kwargs ['until' ] = until .replace (tzinfo = None )
847
+
848
+ self ._construct = kwargs .copy ()
849
+ self ._tzinfo = tzinfo
817
850
self ._rrule = rrule (** self ._construct )
818
851
852
+ def _attach_tzinfo (self , dt , tzinfo ):
853
+ # pytz zones are attached by "localizing" the datetime
854
+ if hasattr (tzinfo , 'localize' ):
855
+ return tzinfo .localize (dt , is_dst = True )
856
+
857
+ return dt .replace (tzinfo = tzinfo )
858
+
859
+ def _aware_return_wrapper (self , f , returns_list = False ):
860
+ """Decorator function that allows rrule methods to handle tzinfo."""
861
+ # This is only necessary if we're actually attaching a tzinfo
862
+ if self ._tzinfo is None :
863
+ return f
864
+
865
+ # All datetime arguments must be naive. If they are not naive, they are
866
+ # converted to the _tzinfo zone before dropping the zone.
867
+ def normalize_arg (arg ):
868
+ if isinstance (arg , datetime .datetime ) and arg .tzinfo is not None :
869
+ if arg .tzinfo is not self ._tzinfo :
870
+ arg = arg .astimezone (self ._tzinfo )
871
+
872
+ return arg .replace (tzinfo = None )
873
+
874
+ return arg
875
+
876
+ def normalize_args (args , kwargs ):
877
+ args = tuple (normalize_arg (arg ) for arg in args )
878
+ kwargs = {kw : normalize_arg (arg ) for kw , arg in kwargs .items ()}
879
+
880
+ return args , kwargs
881
+
882
+ # There are two kinds of functions we care about - ones that return
883
+ # dates and ones that return lists of dates.
884
+ if not returns_list :
885
+ def inner_func (* args , ** kwargs ):
886
+ args , kwargs = normalize_args (args , kwargs )
887
+ dt = f (* args , ** kwargs )
888
+ return self ._attach_tzinfo (dt , self ._tzinfo )
889
+ else :
890
+ def inner_func (* args , ** kwargs ):
891
+ args , kwargs = normalize_args (args , kwargs )
892
+ dts = f (* args , ** kwargs )
893
+ return [self ._attach_tzinfo (dt , self ._tzinfo ) for dt in dts ]
894
+
895
+ return functools .wraps (f )(inner_func )
896
+
819
897
def __getattr__ (self , name ):
820
898
if name in self .__dict__ :
821
899
return self .__dict__ [name ]
822
- return getattr (self ._rrule , name )
900
+
901
+ f = getattr (self ._rrule , name )
902
+
903
+ if name in {'after' , 'before' }:
904
+ return self ._aware_return_wrapper (f )
905
+ elif name in {'xafter' , 'xbefore' , 'between' }:
906
+ return self ._aware_return_wrapper (f , returns_list = True )
907
+ else :
908
+ return f
823
909
824
910
def __setstate__ (self , state ):
825
911
self .__dict__ .update (state )
@@ -1304,7 +1390,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None):
1304
1390
bymonth = [x .item () for x in bymonth .astype (int )]
1305
1391
1306
1392
rule = rrulewrapper (MONTHLY , bymonth = bymonth , bymonthday = bymonthday ,
1307
- interval = interval , ** self .hms0d )
1393
+ interval = interval , ** self .hms0d )
1308
1394
RRuleLocator .__init__ (self , rule , tz )
1309
1395
1310
1396
0 commit comments