Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Latest commit

 

History

History
History
108 lines (94 loc) · 4.43 KB

File metadata and controls

108 lines (94 loc) · 4.43 KB
Copy raw file
Download raw file
Open symbols panel
Edit and raw actions
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#ifndef __COMMON_DIT_HPP__
#define __COMMON_DIT_HPP__
#include "ggml_extend.hpp"
namespace DiT {
    // Helpers shared by DiT-style models for converting between image-like
    // tensors and sequences of flattened patches ("tokens").
    //
    // Shape comments use PyTorch (outermost-first) order, e.g. [N, C, H, W].
    // ggml stores dimensions the other way round, so for such a tensor
    // ne[0] == W, ne[1] == H, ne[2] == C, ne[3] == N.

    /**
     * Split an image-like tensor into non-overlapping ph x pw patches and
     * flatten each patch into one token.
     *
     * @param ctx        ggml context used to create the intermediate ops.
     * @param x          Input tensor, [N, C, H, W]. H must be a multiple of ph
     *                   and W a multiple of pw (asserted).
     * @param pw         Patch width (along W, ggml ne[0]).
     * @param ph         Patch height (along H, ggml ne[1]).
     * @param patch_last If true, each token is laid out as [C, ph, pw]
     *                   (channel outermost); otherwise as [ph, pw, C].
     * @return           [N, h*w, C*ph*pw] if patch_last, else [N, h*w, ph*pw*C],
     *                   with h = H / ph and w = W / pw. Tokens are ordered
     *                   row-major over the (h, w) patch grid.
     *
     * NOTE: the patch size is taken as (pw, ph) here, whereas unpatchify and
     * the pad_* helpers take (ph, pw).
     */
    inline ggml_tensor* patchify(ggml_context* ctx,
                                 ggml_tensor* x,
                                 int pw,
                                 int ph,
                                 bool patch_last = true) {
        int64_t N = x->ne[3];
        int64_t C = x->ne[2];
        int64_t H = x->ne[1];
        int64_t W = x->ne[0];
        int64_t h = H / ph;
        int64_t w = W / pw;
        GGML_ASSERT(h * ph == H && w * pw == W);

        // Expose the patch structure of each row: W -> (w, pw), H -> (h, ph).
        x = ggml_reshape_4d(ctx, x, pw, w, ph, h * C * N);     // [N*C*h, ph, w, pw]
        // Bring the in-patch row (ph) next to the in-patch column (pw).
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, w, ph, pw]
        x = ggml_reshape_4d(ctx, x, pw * ph, w * h, C, N);     // [N, C, h*w, ph*pw]

        if (patch_last) {
            x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, h*w, C, ph*pw]
            x = ggml_reshape_3d(ctx, x, pw * ph * C, w * h, N);    // [N, h*w, C*ph*pw]
        } else {
            x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3));  // [N, h*w, ph*pw, C]
            x = ggml_reshape_3d(ctx, x, C * pw * ph, w * h, N);              // [N, h*w, ph*pw*C]
        }
        return x;
    }

    /**
     * Inverse of patchify: fold a sequence of flattened patches back into an
     * image-like tensor.
     *
     * @param ctx        ggml context used to create the intermediate ops.
     * @param x          [N, h*w, C*ph*pw] if patch_last, else [N, h*w, ph*pw*C].
     *                   ne[0] must be divisible by ph*pw (asserted); C is
     *                   derived from it.
     * @param h          Number of patch rows.
     * @param w          Number of patch columns.
     * @param ph         Patch height.
     * @param pw         Patch width.
     * @param patch_last Must match the token layout used by patchify.
     * @return           [N, C, h*ph, w*pw].
     */
    inline ggml_tensor* unpatchify(ggml_context* ctx,
                                   ggml_tensor* x,
                                   int64_t h,
                                   int64_t w,
                                   int ph,
                                   int pw,
                                   bool patch_last = true) {
        int64_t N = x->ne[2];
        int64_t C = x->ne[0] / ph / pw;
        int64_t H = h * ph;
        int64_t W = w * pw;
        GGML_ASSERT(C * ph * pw == x->ne[0]);

        // Normalize both token layouts to [N, C, h*w, ph*pw].
        if (patch_last) {
            x = ggml_reshape_4d(ctx, x, pw * ph, C, w * h, N);     // [N, h*w, C, ph*pw]
            x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, C, h*w, ph*pw]
        } else {
            x = ggml_reshape_4d(ctx, x, C, pw * ph, w * h, N);     // [N, h*w, ph*pw, C]
            x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));  // [N, C, h*w, ph*pw]
        }

        // Undo the patch interleaving done by patchify.
        x = ggml_reshape_4d(ctx, x, pw, ph, w, h * C * N);     // [N*C*h, w, ph, pw]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, ph, w, pw]
        x = ggml_reshape_4d(ctx, x, W, H, C, N);               // [N, C, h*ph, w*pw]
        return x;
    }

    /**
     * Pad x so that its spatial size is a multiple of the patch size.
     *
     * W (ne[0]) is padded by pad_w and H (ne[1]) by pad_h, each the smallest
     * amount (possibly 0) that makes it divisible by pw / ph. ne[2] and ne[3]
     * are not padded.
     *
     * ctx->circular_x_enabled / circular_y_enabled are forwarded to
     * ggml_ext_pad. NOTE(review): presumably these select wrap-around padding
     * per axis; confirm in ggml_extend.hpp. unpatchify_and_crop keeps the
     * leading H x W region, which assumes the padding is appended at the end.
     */
    inline ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                          ggml_tensor* x,
                                          int ph,
                                          int pw) {
        int64_t W = x->ne[0];
        int64_t H = x->ne[1];
        // The outer "% p" maps an already-aligned size to a pad of 0 instead of p.
        int pad_h = (ph - H % ph) % ph;
        int pad_w = (pw - W % pw) % pw;
        x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
        return x;
    }

    /**
     * Pad x ([N, C, H, W]) up to a multiple of the (ph, pw) patch size, then
     * patchify it.
     *
     * @return [N, h*w, C*ph*pw] if patch_last, else [N, h*w, ph*pw*C],
     *         with h = ceil(H / ph) and w = ceil(W / pw).
     */
    inline ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
                                         ggml_tensor* x,
                                         int ph,
                                         int pw,
                                         bool patch_last = true) {
        x = pad_to_patch_size(ctx, x, ph, pw);
        // patchify takes the patch size as (pw, ph). Passing (ph, pw) here
        // would swap the axes for non-square patches.
        x = patchify(ctx->ggml_ctx, x, pw, ph, patch_last);
        return x;
    }

    /**
     * Inverse of pad_and_patchify: unpatchify onto the padded grid, then crop
     * back to the original H x W.
     *
     * @param H, W  Spatial size before padding; the padded grid size is
     *              recomputed with the same formula as pad_to_patch_size.
     * @return      [N, C, H, W].
     *
     * NOTE(review): assumes ggml_ext_slice(ctx, x, dim, start, end) keeps the
     * half-open range [start, end) along ggml dimension `dim`.
     */
    inline ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
                                            ggml_tensor* x,
                                            int64_t H,
                                            int64_t W,
                                            int ph,
                                            int pw,
                                            bool patch_last = true) {
        int pad_h = (ph - H % ph) % ph;
        int pad_w = (pw - W % pw) % pw;
        int64_t h = ((H + pad_h) / ph);
        int64_t w = ((W + pad_w) / pw);
        x = unpatchify(ctx, x, h, w, ph, pw, patch_last);  // [N, C, H + pad_h, W + pad_w]
        x = ggml_ext_slice(ctx, x, 1, 0, H);               // [N, C, H, W + pad_w]
        x = ggml_ext_slice(ctx, x, 0, 0, W);               // [N, C, H, W]
        return x;
    }
}  // namespace DiT
#endif // __COMMON_DIT_HPP__
Morty Proxy This is a proxified and sanitized view of the page, visit original site.