forked from leejet/stable-diffusion.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample-cache.h
More file actions
61 lines (46 loc) · 1.69 KB
/
sample-cache.h
File metadata and controls
61 lines (46 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#ifndef __SAMPLE_CACHE_H__
#define __SAMPLE_CACHE_H__
#include <vector>
#include "cache_dit.hpp"
#include "denoiser.hpp"
#include "easycache.hpp"
#include "model.h"
#include "spectrum.hpp"
#include "tensor.hpp"
#include "ucache.hpp"
#include "util.h"
namespace sd_sample {
// Step-caching strategy selected for sampling. NONE means no caching.
// Spectrum is not one of these modes: it has its own state and flag
// (SampleCacheRuntime::spectrum / spectrum_enabled). Presumably it can be
// active alongside a mode. NOTE(review): confirm in the .cpp.
enum class SampleCacheMode {
NONE,
EASYCACHE,
UCACHE,
CACHEDIT,
};
// Aggregate cache state, as returned by init_sample_cache_runtime().
// It carries one state member per strategy regardless of `mode`.
// Presumably only the members for the active strategy (plus `spectrum` when
// `spectrum_enabled`) are consulted. TODO confirm against the definitions.
struct SampleCacheRuntime {
// Selected strategy; defaults to NONE (no caching).
SampleCacheMode mode = SampleCacheMode::NONE;
// Strategy-specific state, one member per strategy.
EasyCacheState easycache;
UCacheState ucache;
CacheDitConditionState cachedit;
// Spectrum state, gated by `spectrum_enabled` rather than by `mode`;
// disabled by default.
SpectrumState spectrum;
bool spectrum_enabled = false;
// Mode predicates, defined out of line. Presumably each checks `mode`
// against the corresponding enumerator. TODO confirm.
bool easycache_enabled() const;
bool ucache_enabled() const;
bool cachedit_enabled() const;
};
// Helper constructed per sampling step. By its name, it presumably forwards
// the before/after condition-evaluation hooks to whichever strategy is active
// in `runtime`.
// `runtime` is held by reference. The dispatcher does not own it, so the
// runtime must outlive the dispatcher. The reference member also makes this
// type non-copy-assignable.
struct SampleStepCacheDispatcher {
SampleCacheRuntime& runtime;
// Step number and sigma, as passed to the constructor. `sigma` is presumably
// this step's value from the sigma schedule. TODO confirm.
int step;
float sigma;
// Not a constructor argument. Presumably computed by the constructor from
// `step`/`sigma`/`runtime`. NOTE(review): verify in the definition.
int step_index;
SampleStepCacheDispatcher(SampleCacheRuntime& runtime, int step, float sigma);
// Hook invoked before evaluating `condition` on `input`. `condition` is a
// type-erased pointer whose concrete type is not visible here.
// NOTE(review): presumably returns true when a cached result was written to
// *output, so the caller can skip the evaluation. Confirm the contract, and
// whether `output` may be null.
bool before_condition(const void* condition, const sd::Tensor<float>& input, sd::Tensor<float>* output);
// Hook invoked with the `output` computed for `condition`/`input`.
// Presumably lets the active strategy record it for later reuse.
void after_condition(const void* condition, const sd::Tensor<float>& input, const sd::Tensor<float>& output);
// Whether the cache treats this step as skipped. Defined out of line.
bool is_step_skipped() const;
};
// Builds a SampleCacheRuntime from the model version, the cache parameters,
// the denoiser and the sigma schedule.
// NOTE(review): whether `cache_params` may be null is not visible here.
SampleCacheRuntime init_sample_cache_runtime(SDVersion version,
const sd_cache_params_t* cache_params,
Denoiser* denoiser,
const std::vector<float>& sigmas);
// Logs a summary of `runtime`'s cache activity relative to `total_steps`.
void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t total_steps);
} // namespace sd_sample
#endif // __SAMPLE_CACHE_H__