Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 84c0920

Browse files
Gnurro and abetlen authored
feat: Add loading sharded GGUF files from HuggingFace with Llama.from_pretrained(additional_files=[...]). Closes abetlen#1341
Co-authored-by: Andrei <abetlen@gmail.com>
1 parent 29afcfd commit 84c0920
Copy full SHA for 84c0920

File tree

Expand file treeCollapse file tree

1 file changed

+33
-0
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+33
-0
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard | Expand all lines: llama_cpp/llama.py
+33 | Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2227,6 +2227,7 @@ def from_pretrained(
22272227
cls,
22282228
repo_id: str,
22292229
filename: Optional[str],
2230+
additional_files: Optional[List] = None,
22302231
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
22312232
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
22322233
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
@@ -2239,6 +2240,7 @@ def from_pretrained(
22392240
Args:
22402241
repo_id: The model repo id.
22412242
filename: A filename or glob pattern to match the model file in the repo.
2243+
additional_files: A list of filenames or glob patterns to match additional model files in the repo.
22422244
local_dir: The local directory to save the model to.
22432245
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
22442246
**kwargs: Additional keyword arguments to pass to the Llama constructor.
@@ -2269,6 +2271,7 @@ def from_pretrained(
22692271
rel_path = Path(file).relative_to(repo_id)
22702272
file_list.append(str(rel_path))
22712273

2274+
# find the only/first shard file:
22722275
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
22732276

22742277
if len(matching_files) == 0:
@@ -2298,6 +2301,35 @@ def from_pretrained(
22982301
cache_dir=cache_dir,
22992302
)
23002303

2304+
if additional_files:
2305+
for additonal_file_name in additional_files:
2306+
# find the additional shard file:
2307+
matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)]
2308+
2309+
if len(matching_additional_files) == 0:
2310+
raise ValueError(
2311+
f"No file found in {repo_id} that match {additonal_file_name}\n\n"
2312+
f"Available Files:\n{json.dumps(file_list)}"
2313+
)
2314+
2315+
if len(matching_additional_files) > 1:
2316+
raise ValueError(
2317+
f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
2318+
f"Available Files:\n{json.dumps(files)}"
2319+
)
2320+
2321+
(matching_additional_file,) = matching_additional_files
2322+
2323+
# download the additional file
2324+
hf_hub_download(
2325+
repo_id=repo_id,
2326+
filename=matching_additional_file,
2327+
subfolder=subfolder,
2328+
local_dir=local_dir,
2329+
local_dir_use_symlinks=local_dir_use_symlinks,
2330+
cache_dir=cache_dir,
2331+
)
2332+
23012333
if local_dir is None:
23022334
model_path = hf_hub_download(
23032335
repo_id=repo_id,
@@ -2311,6 +2343,7 @@ def from_pretrained(
23112343
else:
23122344
model_path = os.path.join(local_dir, filename)
23132345

2346+
# loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
23142347
return cls(
23152348
model_path=model_path,
23162349
**kwargs,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.