Commit bffdaf4

Merge branch 'master' into compilade/lazy-convert-hf
2 parents 94e667a + 83330d8


43 files changed: +1719, -252 lines (only part of the diff is shown below)

‎CMakeLists.txt

10 additions, 1 deletion

@@ -103,6 +103,8 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
     "llama: max. batch size for using peer access")
 option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
+option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)
+
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
 option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
 option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
@@ -409,6 +411,9 @@ if (LLAMA_CUDA)
     if (LLAMA_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
+    if (LLAMA_CUDA_NO_VMM)
+        add_compile_definitions(GGML_CUDA_NO_VMM)
+    endif()
     add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     if (DEFINED LLAMA_CUDA_DMMV_Y)
@@ -434,7 +439,11 @@ if (LLAMA_CUDA)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
     endif()

-    set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
+    if (LLAMA_CUDA_NO_VMM)
+        # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
+    else()
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
+    endif()

     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
         # 52 == lowest CUDA 12 standard
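
The new `LLAMA_CUDA_NO_VMM` switch is an ordinary CMake cache option, so it is chosen at configure time. A minimal sketch of a CUDA build with VMM disabled (hypothetical invocation, not part of this commit; all other options are left at their defaults):

```
cmake -B build -DLLAMA_CUDA=ON -DLLAMA_CUDA_NO_VMM=ON
cmake --build build --config Release
```

With the option enabled, `GGML_CUDA_NO_VMM` is defined and the build skips linking `CUDA::cuda_driver`, which per the comment in the diff is otherwise needed only for driver-API calls such as `cuMemGetAllocationGranularity()`.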

‎README.md

20 additions, 24 deletions

@@ -20,7 +20,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

 ### Hot topics

-- **BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920**
+- **Initial Flash-Attention support: https://github.com/ggerganov/llama.cpp/pull/5021**
+- BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920
 - MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387
 - Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404
 - Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225
@@ -935,25 +936,35 @@ If your issue is with model generation quality, then please at least scan the fo

 ### Android

-#### Building the Project using Android NDK
-You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/).
+#### Build on Android using Termux
+[Termux](https://github.com/termux/termux-app#installation) is a method to execute `llama.cpp` on an Android device (no root required).
+```
+apt update && apt upgrade -y
+apt install git make cmake
+```

-First, install the essential packages for termux:
+It's recommended to move your model inside the `~/` directory for best performance:
 ```
-pkg install clang wget git cmake
+cd storage/downloads
+mv model.gguf ~/
 ```
-Second, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:

-You can execute the following commands on your computer to avoid downloading the NDK to your mobile. Of course, you can also do this in Termux.
+[Get the code](https://github.com/ggerganov/llama.cpp#get-the-code) & [follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
+
+#### Building the Project using Android NDK
+Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.

+Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
 ```
 $ mkdir build-android
 $ cd build-android
 $ export NDK=<your_ndk_directory>
 $ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
 $ make
 ```
-Install [termux](https://termux.dev/) on your device and run `termux-setup-storage` to get access to your SD card.
+
+Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
+

 Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:

 (Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
@@ -975,25 +986,10 @@ $cd /data/data/com.termux/files/home/bin
 $./main -m ../model/llama-2-7b-chat.Q4_K_M.gguf -n 128 -cml
 ```

-Here is a demo of an interactive session running on Pixel 5 phone:
+Here's a demo of an interactive session running on Pixel 5 phone:

 https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4

-#### Build on Android using Termux
-[Termux](https://github.com/termux/termux-app#installation) is an alternative to execute `llama.cpp` on an Android device (no root required).
-```
-apt update && apt upgrade -y
-apt install git
-```
-
-It's recommended to move your model inside the `~/` directory for best performance:
-```
-cd storage/downloads
-mv model.gguf ~/
-```
-
-[Follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
-
 ### Docker

 #### Prerequisites

‎ci/run.sh

6 additions, 5 deletions

@@ -160,9 +160,8 @@ function gg_run_test_scripts_debug {

     set -e

-    # TODO: too slow, run on dedicated node
-    #(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
-    #(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
+    (cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
+    (cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log

     set +e
 }
@@ -695,8 +694,10 @@ test $ret -eq 0 && gg_run ctest_release
 if [ -z ${GG_BUILD_LOW_PERF} ]; then
     test $ret -eq 0 && gg_run embd_bge_small

-    test $ret -eq 0 && gg_run test_scripts_debug
-    test $ret -eq 0 && gg_run test_scripts_release
+    if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
+        test $ret -eq 0 && gg_run test_scripts_debug
+        test $ret -eq 0 && gg_run test_scripts_release
+    fi

     if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then
         if [ -z ${GG_BUILD_CUDA} ]; then
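
The new guard means the slow `gguf-split` and `quantize` example-script tests run only when the runner is not marked as a cloud node, or when they are requested explicitly. A hedged sketch of opting back in on a cloud runner (assuming the usual `ci/run.sh <output-dir> <mnt-dir>` calling convention; the paths are placeholders):

```
GG_BUILD_CLOUD=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
```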

‎common/common.cpp

5 additions, 0 deletions

@@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.instruct = true;
         return true;
     }
+    if (arg == "-cnv" || arg == "--conversation") {
+        params.conversation = true;
+        return true;
+    }
     if (arg == "-cml" || arg == "--chatml") {
         params.chatml = true;
         return true;
@@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --version               show version and build info\n");
     printf("  -i, --interactive       run in interactive mode\n");
     printf("  --interactive-first     run in interactive mode and wait for input right away\n");
+    printf("  -cnv, --conversation    run in conversation mode (does not print special tokens and suffix/prefix)\n");
     printf("  -ins, --instruct        run in instruction mode (use with Alpaca models)\n");
     printf("  -cml, --chatml          run in chatml mode (use with ChatML-compatible models)\n");
     printf("  --multiline-input       allows you to write or paste multiple lines without ending each in '\\'\n");
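
Conversation mode is exposed through the regular `main` example, so it is just another command-line switch. A hypothetical invocation (the model path is a placeholder):

```
./main -m models/llama-2-7b-chat.Q4_K_M.gguf -cnv
```

As the usage text above states, `-cnv`/`--conversation` keeps special tokens and the suffix/prefix out of the printed output.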

‎common/common.h

1 addition, 0 deletions

@@ -140,6 +140,7 @@ struct gpt_params {
     bool random_prompt    = false; // do not randomize prompt if none provided
     bool use_color        = false; // use color to distinguish generations and inputs
     bool interactive      = false; // interactive mode
+    bool conversation     = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool chatml           = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro  = false; // open the prompt cache read-only and do not update it

‎common/sampling.cpp

5 additions, 0 deletions

@@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

     result->prev.resize(params.n_prev);

+    result->n_considered = 0;
+
     llama_sampling_set_rng_seed(result, params.seed);

     return result;
@@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {

     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
+    ctx->n_considered = 0;
 }

 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
         }
     }

+    ctx_sampling->n_considered = cur_p.size;
+
     return id;
 }

‎common/sampling.h

1 addition, 0 deletions

@@ -81,6 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+    size_t n_considered;

     std::mt19937 rng;
 };
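
Together with the assignment in `llama_sampling_sample_impl`, the new field records how many candidates were still in `cur_p` when the final token was drawn. A minimal caller-side sketch (not from this commit), assuming the existing `llama_sampling_sample()` and `llama_sampling_accept()` helpers from `common/sampling.h` and already-initialized `ctx_sampling`/`ctx_main` objects:

```
// Hypothetical fragment: draw one token, then report how many candidates
// survived the sampler chain before the final pick.
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, nullptr);
llama_sampling_accept(ctx_sampling, ctx_main, id, /* apply_grammar */ true);

fprintf(stderr, "sampled token %d out of %zu considered candidates\n",
        (int) id, ctx_sampling->n_considered);
```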

‎convert-hf-to-gguf-update.py

5 additions, 0 deletions

@@ -67,6 +67,9 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
     {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
     {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
+    {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
+    {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
+    {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
 ]

 # make directory "models/tokenizers" if it doesn't exist
@@ -150,6 +153,8 @@ def download_file_with_auth(url, token, save_path):
     # print the "pre_tokenizer" content from the tokenizer.json
     with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
         cfg = json.load(f)
+        normalizer = cfg["normalizer"]
+        logger.info("normalizer: " + json.dumps(normalizer, indent=4))
         pre_tokenizer = cfg["pre_tokenizer"]
         logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
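
Besides the three new tokenizer entries, the update script now logs the `normalizer` block of `tokenizer.json` next to the `pre_tokenizer` block. A standalone sketch of the same inspection (hypothetical, assuming the tokenizer has already been downloaded into the directory layout the script uses):

```
import json

name = "qwen2"  # any entry from the models list above
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

print("normalizer:", json.dumps(cfg["normalizer"], indent=4))
print("pre_tokenizer:", json.dumps(cfg["pre_tokenizer"], indent=4))
```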

‎convert-hf-to-gguf.py

12 additions, 2 deletions

@@ -397,6 +397,15 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
             # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
             res = "command-r"
+        if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
+            # ref: https://huggingface.co/Qwen/Qwen1.5-7B
+            res = "qwen2"
+        if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
+            # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
+            res = "olmo"
+        if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
+            # ref: https://huggingface.co/databricks/dbrx-instruct
+            res = "dbrx"

         if res is None:
             logger.warning("\n")
@@ -2248,8 +2257,9 @@ class OlmoModel(Model):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_layer_norm_eps(1e-5)
-        if "clip_qkv" in self.hparams is not None:
-            self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"])
+        clip_qkv = self.hparams.get("clip_qkv")
+        if clip_qkv is not None:
+            self.gguf_writer.add_clamp_kqv(clip_qkv)

     # Same as super class, but permuting q_proj, k_proj
     # Copied from: LlamaModel
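
The `OlmoModel` hunk is a genuine bug fix: `in` and `is not` are both comparison operators, so Python chains them and the old condition only checked key presence, letting a `null` value reach `add_clamp_kqv`. A small illustration with made-up values (not from the repo):

```
hparams = {"clip_qkv": None}  # e.g. a config.json that contains "clip_qkv": null

# old check: ("clip_qkv" in hparams) and (hparams is not None)  -> True
old = "clip_qkv" in hparams is not None

# new check: look at the value itself
clip_qkv = hparams.get("clip_qkv")
new = clip_qkv is not None

print(old, new)  # True False
```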

‎convert.py

32 additions, 19 deletions

@@ -1512,25 +1512,27 @@ def main(args_in: list[str] | None = None) -> None:
     if args.big_endian:
         endianess = gguf.GGUFEndian.BIG

-    params = Params.load(model_plus)
-    if params.n_ctx == -1:
-        if args.ctx is None:
-            msg = """\
-                The model doesn't have a context size, and you didn't specify one with --ctx
-                Please specify one with --ctx:
-                 - LLaMA v1: --ctx 2048
-                 - LLaMA v2: --ctx 4096"""
-            parser.error(textwrap.dedent(msg))
-        params.n_ctx = args.ctx
-
-    if args.outtype:
-        params.ftype = {
-            "f32": GGMLFileType.AllF32,
-            "f16": GGMLFileType.MostlyF16,
-            "q8_0": GGMLFileType.MostlyQ8_0,
-        }[args.outtype]
-
-    logger.info(f"params = {params}")
+    params = None
+    if args.pad_vocab or not args.vocab_only:
+        params = Params.load(model_plus)
+        if params.n_ctx == -1:
+            if args.ctx is None:
+                msg = """\
+                    The model doesn't have a context size, and you didn't specify one with --ctx
+                    Please specify one with --ctx:
+                     - LLaMA v1: --ctx 2048
+                     - LLaMA v2: --ctx 4096"""
+                parser.error(textwrap.dedent(msg))
+            params.n_ctx = args.ctx
+
+        if args.outtype:
+            params.ftype = {
+                "f32": GGMLFileType.AllF32,
+                "f16": GGMLFileType.MostlyF16,
+                "q8_0": GGMLFileType.MostlyQ8_0,
+            }[args.outtype]
+
+        logger.info(f"params = {params}")

     model_parent_path = model_plus.paths[0].parent
     vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
@@ -1543,6 +1545,17 @@ def main(args_in: list[str] | None = None) -> None:
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         outfile = args.outfile
+        if params is None:
+            params = Params(
+                n_vocab = vocab.vocab_size,
+                n_embd = 1,
+                n_layer = 1,
+                n_ctx = 1,
+                n_ff = 1,
+                n_head = 1,
+                n_head_kv = 1,
+                f_norm_eps = 1e-5,
+            )
         OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
                                     endianess=endianess, pad_vocab=args.pad_vocab)
         logger.info(f"Wrote {outfile}")
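
With `Params.load` now skipped for `--vocab-only` runs that do not use `--pad-vocab`, the script fills in placeholder hyperparameters and writes only the vocabulary. A hypothetical invocation (paths are placeholders):

```
python convert.py path/to/hf-model-dir --vocab-only --outfile tokenizer.gguf
```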

‎docs/BLIS.md

1 addition, 1 deletion

@@ -23,7 +23,7 @@ Install BLIS:
 sudo make install
 ```

-We recommend using openmp since it's easier to modify the cores been used.
+We recommend using openmp since it's easier to modify the cores being used.

 ### llama.cpp compilation

‎docs/HOWTO-add-model.md

2 additions, 2 deletions

@@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc

 This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.

-Have a look to existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
+Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.

-When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support of missing backend operations can be added in another PR.
+When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.

 Note: to debug the inference graph: you can use [eval-callback](../examples/eval-callback).

‎examples/finetune/finetune.cpp

1 addition, 1 deletion

@@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);

     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
             return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
         } else if (a->type == GGML_TYPE_F32) {
             return ggml_add(ctx, a, b);

0 commit comments
