Commit a315cac

llama/ggml: add LLM training support
more compact progress bar
refactor: llama_prepare_sbatch/ubatch
llama_save_model_to_file
gqa_mode arg for repeat_back
llama_opt_param_filter
ggml_graph_dup force_grads
refactor ggml_opt, fix test-opt
1 parent 9c8dcef commit a315cac

28 files changed: +1514, -490 lines

‎common/common.cpp

16 additions, 0 deletions

@@ -2074,3 +2074,19 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
     return result;
 }
 
+ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
+    const int64_t ne_datapoint = llama_n_ctx(ctx);
+    const int64_t ndata        = (tokens.size() - ne_datapoint - 1) / stride;
+    ggml_opt_dataset_t result  = ggml_opt_dataset_init(
+        GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
+
+    llama_token * data   = (llama_token *) ggml_opt_dataset_data(result)->data;
+    llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        memcpy(data   + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
+        memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
+    }
+
+    return result;
+}
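
The helper above turns one token stream into overlapping next-token-prediction examples. As a purely illustrative walk-through (not part of the commit), assume n_ctx = 4, stride = 2, and 11 input tokens t0..t10:

// ndata = (11 - 4 - 1) / 2 = 3 datapoints
//   idata 0: data = {t0, t1, t2, t3}   labels = {t1, t2, t3, t4}
//   idata 1: data = {t2, t3, t4, t5}   labels = {t3, t4, t5, t6}
//   idata 2: data = {t4, t5, t6, t7}   labels = {t5, t6, t7, t8}
// each label row is its data row shifted forward by one token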

‎common/common.h

6 additions, 0 deletions

@@ -672,3 +672,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
 const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 
 }
+
+//
+// training utils
+//
+
+ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

‎examples/CMakeLists.txt

1 addition, 0 deletions

@@ -53,6 +53,7 @@ else()
     add_subdirectory(tokenize)
     add_subdirectory(tts)
     add_subdirectory(gen-docs)
+    add_subdirectory(training)
     if (NOT GGML_BACKEND_DL)
         # these examples use the backends directly and cannot be built with dynamic loading
         add_subdirectory(convert-llama2c-to-ggml)

‎examples/training/CMakeLists.txt

5 additions, 0 deletions

@@ -0,0 +1,5 @@
+set(TARGET llama-finetune)
+add_executable(${TARGET} finetune.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

‎examples/training/README.md

4 additions, 0 deletions

@@ -0,0 +1,4 @@
+# llama.cpp/examples/training
+
+This directory contains examples related to language model training using llama.cpp/GGML.
+So far finetuning is technically functional (for FP32 models) but the code is very much WIP.

‎examples/training/finetune.cpp

97 additions, 0 deletions

@@ -0,0 +1,97 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.logits_all = true;
+    params.escape     = false;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+        return 1;
+    }
+
+    if (params.use_mmap) {
+        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n", __func__);
+        params.use_mmap = false;
+    }
+    if (params.cache_type_k == GGML_TYPE_F16) {
+        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_k = GGML_TYPE_F32;
+    }
+    if (params.cache_type_v == GGML_TYPE_F16) {
+        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_v = GGML_TYPE_F32;
+    }
+
+    common_init();
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // load the model and apply lora adapter, if any
+    common_init_result llama_init = common_init_from_params(params);
+    llama_model_ptr   & model     = llama_init.model;
+    llama_context_ptr & ctx       = llama_init.context;
+
+    if (model == NULL) {
+        LOG_ERR("%s: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    }
+
+    constexpr float val_split = 0.05f;
+
+    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
+
+    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
+    optimizer_params.adamw.alpha = 1e-7f; // learning rate
+
+    struct llama_opt_params lopt_params {
+        /*n_ctx_train     =*/ 0,
+        /*param_filter    =*/ llama_opt_param_filter_all,
+        /*param_filter_ud =*/ nullptr,
+        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
+        /*get_opt_pars_ud =*/ &optimizer_params,
+    };
+    llama_opt_init(ctx.get(), model.get(), lopt_params);
+
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+
+    ggml_opt_result_t result_train = ggml_opt_result_init();
+    ggml_opt_result_t result_eval  = ggml_opt_result_init();
+
+    for (int epoch = 0; epoch < 2; ++epoch) {
+        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
+            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+        fprintf(stderr, "\n");
+
+        ggml_opt_result_reset(result_train);
+        ggml_opt_result_reset(result_eval);
+    }
+    ggml_opt_result_free(result_train);
+    ggml_opt_result_free(result_eval);
+
+    llama_save_model_to_file(model.get(), "finetuned-model.gguf");
+
+    llama_backend_free();
+
+    return 0;
+}
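
finetune.cpp keeps the optimizer parameters constant via ggml_opt_get_constant_optimizer_params. A hypothetical alternative callback (not part of the commit) matching the ggml_opt_get_optimizer_params typedef from ggml-opt.h below; the assumption that userdata points at the current epoch follows the comment on ggml_opt_fit's get_opt_pars argument:

// Hypothetical sketch, not in the commit: a callback with the
// ggml_opt_get_optimizer_params signature that lowers the AdamW learning rate
// over time. userdata is assumed to point at the current epoch (int64_t).
static struct ggml_opt_optimizer_params decaying_opt_pars(void * userdata) {
    const int64_t epoch = *(const int64_t *) userdata;

    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(nullptr);
    p.adamw.alpha = 1e-7f / (1.0f + epoch); // shrink the learning rate each epoch
    return p;
}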

‎ggml/include/ggml-opt.h

47 additions, 28 deletions

@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======
 
     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,  // the type for the internal data tensor
+            enum ggml_type type_label, // the type for the internal labels tensor
+            int64_t ne_datapoint,      // number of elements per datapoint
+            int64_t ne_label,          // number of elements per label
+            int64_t ndata,             // total number of datapoints/labels
+            int64_t ndata_shard);      // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
 
     // get underlying tensors that store the data
+    GGML_API int64_t              ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
 
@@ -56,13 +59,19 @@ extern "C" {
             struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
             int64_t              ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t   dataset,
+            void               * data_batch,
+            size_t               nb_data_batch,
+            void               * labels_batch,
+            int64_t              ibatch);
 
     // ====== Model / Context ======
 
     enum ggml_opt_build_type {
-        GGML_OPT_BUILD_TYPE_FORWARD,
-        GGML_OPT_BUILD_TYPE_GRAD,
-        GGML_OPT_BUILD_TYPE_OPT,
+        GGML_OPT_BUILD_TYPE_FORWARD = 10,
+        GGML_OPT_BUILD_TYPE_GRAD    = 20,
+        GGML_OPT_BUILD_TYPE_OPT     = 30,
     };
 
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
     // userdata can be used to pass arbitrary data
     typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
 
-    // returns the default optimizer params (constant)
+    // returns the default optimizer params (constant, hard-coded values)
     // userdata is not used
     GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
 
+    // casts userdata to ggml_opt_optimizer_params and returns it
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
+
     // parameters for initializing a new optimization context
     struct ggml_opt_params {
        ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
 
-        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
-        // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
-        struct ggml_tensor * inputs;
-        struct ggml_tensor * outputs;
+        // by default the forward graph needs to be reconstructed for each eval
+        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+        struct ggml_context * ctx_compute;
+        struct ggml_tensor  * inputs;
+        struct ggml_tensor  * outputs;
 
         enum ggml_opt_loss_type  loss_type;
         enum ggml_opt_build_type build_type;
@@ -107,12 +118,9 @@ extern "C" {
 
     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
-            ggml_backend_sched_t      backend_sched,
-            struct ggml_context     * ctx_compute,
-            struct ggml_tensor      * inputs,
-            struct ggml_tensor      * outputs,
-            enum ggml_opt_loss_type   loss_type);
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t    backend_sched,
+            enum ggml_opt_loss_type loss_type);
 
     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@@ -121,18 +129,20 @@ extern "C" {
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
 
     // get underlying tensors that store data
+    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
     GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
     GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
     GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
     GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
     GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
 
+    // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
     // ====== Optimization Result ======
 
-    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
     GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
     GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
 
@@ -144,11 +154,20 @@ extern "C" {
 
     // ====== Computation ======
 
-    // do forward pass, increment result if not NULL
-    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // if not using static graphs, this function must be called prior to ggml_opt_alloc
+    GGML_API void ggml_opt_prepare_alloc(
+            ggml_opt_context_t    opt_ctx,
+            struct ggml_context * ctx_compute,
+            struct ggml_cgraph  * gf,
+            struct ggml_tensor  * inputs,
+            struct ggml_tensor  * outputs);
+
+    // allocate the next graph for evaluation, either forward or forward + backward
+    // must be called exactly once prior to calling ggml_opt_eval
+    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
 
-    // do forward pass, increment result if not NULL, do backward pass
-    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // do forward pass, increment result if not NULL, do backward pass if allocated
+    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
 
     // ############################################################################
     // ## The high-level functions start here. They do not depend on any private ##
@@ -200,9 +219,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
             ggml_backend_sched_t          backend_sched, // backend scheduler for constructing the compute graphs
-            ggml_context                * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
-            ggml_tensor                 * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
-            ggml_tensor                 * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            struct ggml_context         * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
+            struct ggml_tensor          * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
+            struct ggml_tensor          * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t            dataset,       // dataset with data and optionally also labels
             enum ggml_opt_loss_type       loss_type,     // loss to minimize
             ggml_opt_get_optimizer_params get_opt_pars,  // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
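
The header above replaces ggml_opt_forward/ggml_opt_forward_backward with a prepare/alloc/eval sequence for dynamically built graphs. A minimal sketch (not from the commit) of one training step under the new convention; opt_ctx, ctx_compute, gf, inputs, and outputs are assumed to have been created by the caller, e.g. via ggml_opt_init(ggml_opt_default_params(backend_sched, loss_type)) plus a freshly built forward graph:

#include "ggml-opt.h"

// One optimizer step with dynamically (re)built graphs; every name passed in
// here is assumed to be set up by the caller as described in the lead-in.
static void train_step_sketch(ggml_opt_context_t    opt_ctx,
                              ggml_opt_result_t     result,
                              struct ggml_context * ctx_compute,
                              struct ggml_cgraph  * gf,        // freshly built forward graph
                              struct ggml_tensor  * inputs,    // forward graph input tensor
                              struct ggml_tensor  * outputs) { // forward graph output tensor
    // without static graphs the forward graph has to be handed over before each allocation:
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, inputs, outputs);

    // allocate forward + backward graphs; must be called exactly once before ggml_opt_eval:
    ggml_opt_alloc(opt_ctx, /*backward =*/ true);

    // the caller would now fill ggml_opt_inputs(opt_ctx) and ggml_opt_labels(opt_ctx),
    // e.g. from a batch obtained via ggml_opt_dataset_get_batch(), then evaluate:
    ggml_opt_eval(opt_ctx, result); // forward pass, result accumulation, backward pass
}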

‎ggml/include/ggml.h

7 additions, 7 deletions

@@ -763,7 +763,7 @@ extern "C" {
     // Tensor flags
     GGML_API void ggml_set_input(struct ggml_tensor * tensor);
     GGML_API void ggml_set_output(struct ggml_tensor * tensor);
-    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_param(struct ggml_tensor * tensor);
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
     //
@@ -933,7 +933,8 @@
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            bool gqa_mode); // use memory pattern for backward pass of mat. mul. with group-query attention
 
     // concat a and b along dim
     // used in stable-diffusion
@@ -2054,15 +2055,14 @@
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
-        struct ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
-        struct ggml_context * ctx_compute, // context for gradient computation
-        struct ggml_cgraph  * cgraph,
-        bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+        struct ggml_context  * ctx, // context for gradient computation
+        struct ggml_cgraph   * cgraph,
+        struct ggml_tensor  ** grad_accs);
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
-    GGML_API struct ggml_cgraph * ggml_graph_dup        (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph * ggml_graph_dup        (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
     GGML_API void                 ggml_graph_cpy        (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset      (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear      (struct ggml_cgraph * cgraph);
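
Taken together, these changes mean call sites now mark trainable tensors without a context argument and hand gradient accumulators to ggml_build_backward_expand explicitly. A hypothetical call-site sketch (not from the commit); ctx, weights, and loss are assumed to exist in the caller, and passing NULL for grad_accs is assumed to request plain, non-accumulated gradients:

// Hypothetical call-site update, not part of the commit.
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);

ggml_set_param(weights);             // was: ggml_set_param(ctx, weights);
ggml_build_forward_expand(gf, loss); // forward graph ending in the loss

// gradients are added to the same graph; NULL grad_accs is assumed to mean
// "no pre-existing accumulator tensors":
ggml_build_backward_expand(ctx, gf, /*grad_accs =*/ NULL);

// duplicating a graph can now be forced to also duplicate its gradient bookkeeping:
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf, /*force_grads =*/ true);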

‎ggml/src/ggml-backend.cpp

2 additions, 1 deletion

@@ -333,6 +333,7 @@ enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    // GGML_ASSERT(ggml_backend_dev_supports_op(backend->device, op));
     return ggml_backend_dev_supports_op(backend->device, op);
 }
 
@@ -1107,7 +1108,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
         const int node_backend_id = tensor_backend_id(node);
 
-        assert(node_backend_id != -1); // all nodes should be assigned by now
+        assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
 
         // check if we should start a new split based on the sources of the current node
         bool need_new_split = false;
