File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed
Filter options
Expand file tree Collapse file tree 1 file changed +16
-2
lines changed
Original file line number Diff line number Diff line change 19
19
#include < af/array.h>
20
20
#include < af/data.h>
21
21
#include < af/defines.h>
22
+ #include < af/device.h>
22
23
#include < af/dim4.hpp>
24
+ #include < map>
23
25
#include < memory>
24
26
25
27
using af::dim4;
@@ -128,8 +130,20 @@ af_err af_get_default_random_engine(af_random_engine *r) {
128
130
try {
129
131
AF_CHECK (af_init ());
130
132
131
- thread_local auto *re = new RandomEngine;
132
- *r = static_cast <af_random_engine>(re);
133
+ // RandomEngine contains device buffers which are dependent on
134
+ // context|stream/device. Since nor context or stream are available at
135
+ // this level, we will only use the deviceId.
136
+ thread_local std::map<int /* deviceId*/ , RandomEngine *>
137
+ cachedDefaultRandomEngines;
138
+ const int dependent = af::getDevice ();
139
+ auto it = cachedDefaultRandomEngines.find (dependent);
140
+ if (it == cachedDefaultRandomEngines.end ()) {
141
+ RandomEngine *defaultRandomEngine = new RandomEngine;
142
+ cachedDefaultRandomEngines[dependent] = defaultRandomEngine;
143
+ *r = static_cast <af_random_engine>(defaultRandomEngine);
144
+ } else {
145
+ *r = static_cast <af_random_engine>(it->second );
146
+ }
133
147
return AF_SUCCESS;
134
148
}
135
149
CATCHALL;
You can’t perform that action at this time.
0 commit comments