Commit 443e7e7

Merge branch 'mamba2-sync' into GraniteFour
* mamba2-sync: (24 commits)
  sync : ggml
  Add `ggml_roll` (ggml/1274)
  docs : fix the link to llama.h (ggml-org#14293)
  CUDA: add conv_2d_transpose (ggml-org#14287)
  lint : remove trailing whitepace (ggml-org#14304)
  vocab : prevent tokenizer overflow (ggml-org#14301)
  sycl: add usage of enqueue_functions extension (ggml-org#14244)
  Implement GGML_CPU_ALL_VARIANTS for PowerPC (ggml-org#14286)
  llama : improve sep token handling (ggml-org#14272)
  cuda : synchronize graph capture and cublas handle destruction (ggml-org#14288)
  ggml : fix repack work size for mul_mat_id (ggml-org#14292)
  ggml: Update KleidiAI to v1.9.0 (ggml-org#14277)
  model : more uniform output id handling (ggml-org#14275)
  ubatch : new splitting logic (ggml-org#14217)
  CUDA: add conv_2d_dw (ggml-org#14265)
  ggml-cpu : remove unnecesary arm feature detection (ggml-org#14281)
  gguf-py : make sentencepiece optional (ggml-org#14200)
  server : add server parameters for draft model cache type (ggml-org#13782)
  build : suppress gcc15 compile warnings (ggml-org#14261)
  sycl: Cleanup codepaths in Get Rows in sycl backend (ggml-org#14215)
  ...
2 parents 8f3af99 + b605bb9 commit 443e7e7

82 files changed (+4212, -3693 lines)

‎ci/run.sh

+1 -1 (1 addition & 1 deletion)

@@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
     model_f16="${path_models}/ggml-model-f16.gguf"

     # for this model, the SEP token is "</s>"
-    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log

     # sample output
     # rerank score 0: 0.029

‎common/arg.cpp

+33 -0 (33 additions & 0 deletions)

@@ -2706,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.embd_sep = value;
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    add_opt(common_arg(
+        {"--cls-separator"}, "STRING",
+        "separator of classification sequences (default \\t) for example \"<#seq#>\"",
+        [](common_params & params, const std::string & value) {
+            params.cls_sep = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--host"}, "HOST",
         string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
@@ -3210,6 +3217,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.model.path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
+    add_opt(common_arg(
+        {"-ctkd", "--cache-type-k-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for K for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_k)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_k = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
+    add_opt(common_arg(
+        {"-ctvd", "--cache-type-v-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for V for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_v)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_v = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));

     add_opt(common_arg(
         {"-mv", "--model-vocoder"}, "FNAME",

‎common/common.cpp

+9 -0 (9 additions & 0 deletions)

@@ -706,11 +706,17 @@ bool fs_validate_filename(const std::string & filename) {
         // disable C++17 deprecation warning for std::codecvt_utf8
 # pragma clang diagnostic push
 # pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif
+
         std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;

 #if defined(__clang__)
 # pragma clang diagnostic pop
+#elif defined(__GNUC__)
+# pragma GCC diagnostic pop
 #endif

         filename_utf32 = converter.from_bytes(filename);
@@ -1284,6 +1290,9 @@ std::vector<llama_token> common_tokenize(
     int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
     n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens == std::numeric_limits<int32_t>::min()) {
+        throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
+    }
     if (n_tokens < 0) {
         result.resize(-n_tokens);
         int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
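
With the guard above, an oversized input now surfaces as an exception rather than a bogus token count, so callers of `common_tokenize` that accept untrusted text may want to catch it. A hedged sketch (the wrapper below is illustrative and not part of this commit):

    #include "common.h"  // common_tokenize, llama_token (llama.cpp/common)
    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Illustrative caller: treat the overflow as a recoverable input error.
    static std::vector<llama_token> tokenize_or_empty(llama_context * ctx, const std::string & text) {
        try {
            return common_tokenize(ctx, text, /*add_special=*/ true, /*parse_special=*/ true);
        } catch (const std::runtime_error & err) {
            fprintf(stderr, "tokenization failed: %s\n", err.what());
            return {};
        }
    }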

‎common/common.h

+4 -0 (4 additions & 0 deletions)

@@ -199,6 +199,9 @@ struct common_params_speculative {
     float p_split = 0.1f; // speculative decoding split probability
     float p_min = 0.75f; // minimum speculative decoding probability (greedy)

+    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
+    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
+
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;

@@ -355,6 +358,7 @@ struct common_params {
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
+    std::string cls_sep = "\t"; // separator of classification sequences

     // server params
     int32_t port = 8080; // server listens on this network port

‎convert_hf_to_gguf.py

+11 -23 (11 additions & 23 deletions)

@@ -2145,7 +2145,6 @@ def __init__(self, *args, **kwargs):

     def set_vocab(self):
         self._set_vocab_gpt2()
-        self.gguf_writer.add_add_bos_token(True)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3918,9 +3917,6 @@ def _xlmroberta_set_vocab(self) -> None:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
-

 @ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
 class DistilBertModel(BertModel):
@@ -3962,8 +3958,6 @@ def set_vocab(self):
         bpe_tok_path = self.dir_model / "tokenizer.json"
         if bpe_tok_path.exists():
             self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)

         # we need this to validate the size of the token_type embeddings
         # though currently we are passing all zeros to the token_type embeddings
@@ -5056,8 +5050,6 @@ def set_vocab(self):
             self.gguf_writer.add_token_type_count(2)
         else:
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)


 @ModelBase.register("OpenELMForCausalLM")
@@ -5659,9 +5651,6 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5799,9 +5788,6 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -6630,8 +6616,8 @@ def parse_args() -> argparse.Namespace:
         help="model is executed on big endian machine",
     )
     parser.add_argument(
-        "model", type=Path,
-        help="directory containing model file",
+        "model", type=str,
+        help="directory containing model file or huggingface repository ID (if --remote)",
         nargs="?",
     )
     parser.add_argument(
@@ -6742,18 +6728,20 @@ def main() -> None:
     else:
         logging.basicConfig(level=logging.INFO)

-    dir_model = args.model
-
     if args.remote:
+        hf_repo_id = args.model
         from huggingface_hub import snapshot_download
         local_dir = snapshot_download(
-            repo_id=str(dir_model),
+            repo_id=hf_repo_id,
             allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
+    else:
+        hf_repo_id = None
+        dir_model = Path(args.model)

     if not dir_model.is_dir():
-        logger.error(f'Error: {args.model} is not a directory')
+        logger.error(f'Error: {dir_model} is not a directory')
         sys.exit(1)

     ftype_map: dict[str, gguf.LlamaFileType] = {
@@ -6773,9 +6761,9 @@ def main() -> None:

     if args.outfile is not None:
         fname_out = args.outfile
-    elif args.remote:
+    elif hf_repo_id:
         # if remote, use the model ID as the output file name
-        fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
+        fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
     else:
         fname_out = dir_model

@@ -6804,7 +6792,7 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=str(args.model) if args.remote else None)
+            remote_hf_model_id=hf_repo_id)

     if args.vocab_only:
         logger.info("Exporting model vocab...")

‎docs/build.md

+1 -1 (1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 # Build llama.cpp locally

-The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
+The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).

 The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.

‎examples/embedding/embedding.cpp

+30 -4 (30 additions & 4 deletions)

@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
     // max batch size
     const uint64_t n_batch = params.n_batch;

+    // get added sep and eos token, if any
+    const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
+    const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
     for (const auto & prompt : prompts) {
-        auto inp = common_tokenize(ctx, prompt, true, true);
+        std::vector<llama_token> inp;
+
+        // split classification pairs and insert expected separator tokens
+        if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
+            std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
+            std::string final_prompt;
+
+            for (size_t i = 0; i < pairs.size(); i++) {
+                final_prompt += pairs[i];
+                if (i != pairs.size() - 1) {
+                    if (!added_eos_token.empty()) {
+                        final_prompt += added_eos_token;
+                    }
+                    if (!added_sep_token.empty()) {
+                        final_prompt += added_sep_token;
+                    }
+                }
+            }
+
+            inp = common_tokenize(ctx, final_prompt, true, true);
+        } else {
+            inp = common_tokenize(ctx, prompt, true, true);
+        }
         if (inp.size() > n_batch) {
             LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                     __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }

-    // check if the last token is SEP
+    // check if the last token is SEP/EOS
     // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
-            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
+        if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
+            LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
             LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }
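
For a rank-pooling prompt such as "what is panda?\tit's a bear", the branch above rebuilds the separator sequence the reranker expects by inserting the vocab's added EOS and SEP texts between the two segments. A standalone sketch of just that joining step (the token strings here are placeholders taken from the ci/run.sh comment; the real ones are read from the vocab as shown above and may be empty):

    #include <string>
    #include <vector>

    // Placeholder token texts; embedding.cpp takes them from the model vocab.
    static const std::string k_added_eos = "</s>";
    static const std::string k_added_sep = "</s>";

    // "query<TAB>document" -> "query</s></s>document" with the placeholders above,
    // mirroring the LLAMA_POOLING_TYPE_RANK branch.
    static std::string join_rank_segments(const std::vector<std::string> & segments) {
        std::string out;
        for (size_t i = 0; i < segments.size(); i++) {
            out += segments[i];
            if (i + 1 < segments.size()) {
                out += k_added_eos;
                out += k_added_sep;
            }
        }
        return out;
    }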

‎ggml/include/ggml.h

+12 -0 (12 additions & 0 deletions)

@@ -489,6 +489,7 @@ extern "C" {
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
+        GGML_OP_ROLL,
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
@@ -1801,6 +1802,17 @@ extern "C" {
             int                   p0,
             int                   p1);

+    // Move tensor elements by an offset given for each dimension. Elements that
+    // are shifted beyond the last position are wrapped around to the beginning.
+    GGML_API struct ggml_tensor * ggml_roll(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   shift0,
+            int                   shift1,
+            int                   shift2,
+            int                   shift3);
+
+
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
     // return: [N, dim]
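
A minimal end-to-end sketch of the new op, assuming a CPU-only build and the `ggml_graph_compute_with_ctx` helper from `ggml-cpu.h`; the expected output follows the wrap-around semantics described in the comment above:

    #include "ggml.h"
    #include "ggml-cpu.h" // ggml_graph_compute_with_ctx
    #include <cstdio>

    int main() {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // a = [0, 1, 2, 3]
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        for (int i = 0; i < 4; i++) {
            ((float *) a->data)[i] = (float) i;
        }

        // shift by 1 along dim 0; the element pushed past the end wraps to the
        // front, so the expected result is [3, 0, 1, 2]
        struct ggml_tensor * r = ggml_roll(ctx, a, 1, 0, 0, 0);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, r);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        for (int i = 0; i < 4; i++) {
            printf("%g ", ((float *) r->data)[i]); // expected: 3 0 1 2
        }
        printf("\n");

        ggml_free(ctx);
        return 0;
    }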

‎ggml/src/CMakeLists.txt

+17 -0 (17 additions & 0 deletions)

@@ -286,6 +286,10 @@ function(ggml_add_cpu_backend_variant tag_name)
         foreach (feat ${ARGN})
             set(GGML_INTERNAL_${feat} ON)
         endforeach()
+    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
+        foreach (feat ${ARGN})
+            set(GGML_INTERNAL_${feat} ON)
+        endforeach()
     endif()

     ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -337,6 +341,19 @@ if (GGML_CPU_ALL_VARIANTS)
         else()
             message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
         endif()
+    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
+        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+            ggml_add_cpu_backend_variant(power0)
+            ggml_add_cpu_backend_variant(power7_1 POWER7)
+            ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
+            ggml_add_cpu_backend_variant(power8_1 POWER8)
+            ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
+            ggml_add_cpu_backend_variant(power9 POWER9 VSX)
+            ggml_add_cpu_backend_variant(power10 POWER10 VSX)
+            ggml_add_cpu_backend_variant(power11 POWER11 VSX)
+        else()
+            message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
+        endif()
     else()
         message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
     endif()

‎ggml/src/ggml-backend-reg.cpp

+5 -0 (5 additions & 0 deletions)

@@ -69,6 +69,9 @@
 #if defined(__clang__)
 # pragma clang diagnostic push
 # pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif

 namespace fs = std::filesystem;

@@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {

 #if defined(__clang__)
 # pragma clang diagnostic pop
+#elif defined(__GNUC__)
+# pragma GCC diagnostic pop
 #endif

 #ifdef _WIN32

‎ggml/src/ggml-cpu/CMakeLists.txt

+23 -2 (23 additions & 2 deletions)

@@ -388,6 +388,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         else()
             list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
         endif()
+    elseif(GGML_CPU_ALL_VARIANTS)
+        # Begin with the lowest baseline
+        set(ARCH_DEFINITIONS "")
+
+        # When a feature is selected, bump the MCPU to the first
+        # version that supported it
+        foreach(PVER RANGE 7 11)
+            if(DEFINED GGML_INTERNAL_POWER${PVER})
+                set(POWERPC_MCPU "power${PVER}")
+                list(APPEND ARCH_DEFINITIONS GGML_USE_POWER${PVER})
+            endif()
+        endforeach()
+        if (GGML_INTERNAL_VSX)
+            list(APPEND ARCH_DEFINITIONS GGML_USE_VSX)
+            list(APPEND ARCH_FLAGS -mvsx)
+        endif()
+
+        if (DEFINED POWERPC_MCPU)
+            list(APPEND ARCH_FLAGS -mcpu=${POWERPC_MCPU})
+        endif()
+        ggml_add_cpu_backend_features(${GGML_CPU_NAME} powerpc ${ARCH_DEFINITIONS})
     else()
         if (GGML_CPU_POWERPC_CPUTYPE)
             list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
@@ -465,9 +486,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.6.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.9.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")
+        set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")

         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
