Commit 6bfe98b

Integration of Jinja2 Templating (abetlen#875)
* feat: Add support for jinja templating

* fix: Refactor chat formatter and update interface for jinja templates

  - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality.
  - Update the `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
  - Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string.
  - Replace the `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.

  These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage.

* Add outline for Jinja2 templating integration documentation

* Add jinja2 as a dependency with a version range for Hugging Face transformers compatibility

* Update jinja2 version constraint for mkdocs-material compatibility

* Fix attribute name in AutoChatFormatter: changed `self._renderer` to `self._environment`

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
1 parent 52adc23 commit 6bfe98b

4 files changed: +243 −1 lines changed
docs/templates.md — new file, +52 lines
# Templates

This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.

## Introduction

- Brief explanation of the `llama-cpp-python` project's need for a templating system.
- Overview of the `llama-2` model's interaction with templating.

## Jinja2 Dependency Integration

- Rationale for choosing Jinja2 as the templating engine.
- Compatibility with Hugging Face's `transformers` (see the sketch below).
- Desire for advanced templating features and simplicity.
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
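Because `transformers` stores chat templates as plain Jinja2 strings, a template taken from a Hugging Face tokenizer can, in principle, be passed straight to the formatter. A minimal sketch — the ChatML-style `chatml_template` string here is illustrative, not something shipped by this commit:

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter

# Illustrative ChatML-style template, written in the same Jinja2 dialect
# that Hugging Face tokenizers use for their `chat_template` field.
chatml_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
)

formatter = AutoChatFormatter(template=chatml_template)
print(formatter([{"role": "user", "content": "Hi"}]).prompt)
# <|im_start|>user
# Hi<|im_end|>
```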
## Template Management Refactor

- Summary of the refactor and the motivation behind it.
- Description of the new chat handler selection logic (sketched below):
  1. Preference for a user-specified `chat_handler`.
  2. Fallback to a user-specified `chat_format`.
  3. Defaulting to a chat format from a `.gguf` file if available.
  4. Utilizing the `llama2` default chat format as the final fallback.
- Ensuring backward compatibility throughout the refactor.
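A minimal sketch of that fallback chain — the helper name and parameters are hypothetical; the actual wiring lives in the `Llama` constructor and is not part of this diff:

```python
from llama_cpp.llama_jinja_format import ChatFormatterFactory

def resolve_chat_formatter(
    chat_handler=None,      # 1. an explicitly supplied handler wins
    chat_format=None,       # 2. otherwise a user-specified format name
    gguf_chat_format=None,  # 3. otherwise a format read from .gguf metadata
):
    """Hypothetical helper illustrating the documented fallback order."""
    if chat_handler is not None:
        return chat_handler
    name = chat_format or gguf_chat_format or "llama-2"  # 4. final fallback
    return ChatFormatterFactory().get_formatter_by_name(name)
```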
## Implementation Details

- In-depth look at the new `AutoChatFormatter` class.
- Example code snippets showing how to utilize the Jinja2 environment and templates (see below).
- Guidance on how to provide custom templates or use defaults.
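For instance, a minimal usage sketch built on the `AutoChatFormatter` class added in this commit:

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter

# With no arguments, AutoChatFormatter falls back to the built-in llama-2 template.
formatter = AutoChatFormatter()
response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # "[INST] Hello! [/INST]\n"
```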
## Testing and Validation

- Outline of the testing strategy to ensure seamless integration.
- Steps for validating backward compatibility with existing implementations.

## Benefits and Impact

- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
- Discussion of the potential impact on current users and contributors.

## Future Work

- Exploration of how templating can evolve within the project.
- Consideration of additional features or optimizations for the templating engine.
- Mechanisms for community feedback on the templating system.

## Conclusion

- Final thoughts on the integration of Jinja2 templating.
- Call to action for community involvement and feedback.

llama_cpp/llama_jinja_format.py — new file, +138 lines
```python
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Type, Union

import jinja2
from jinja2 import Template

# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()


@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...


class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        # A Template subclass (not an instance), matching the
        # `template_class` parameter of `Environment.from_string`.
        template_class: Optional[Type[Template]] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        # Compile the template once up front; `_environment` holds the
        # jinja2.Template object returned by `Environment.from_string`.
        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template


class FormatterNotFoundException(Exception):
    pass


class ChatFormatterFactory(Singleton):
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        try:
            formatter_callable = self._chat_formatters[name]
            return formatter_callable()
        except KeyError:
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            )


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
```
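A short usage sketch of the factory as registered above:

```python
# ChatFormatterFactory is a singleton, so this returns the same instance
# that registered "llama-2" at import time.
factory = ChatFormatterFactory()
llama2_formatter = factory.get_formatter_by_name("llama-2")
result = llama2_formatter([{"role": "system", "content": "Be concise."}])
print(result.prompt)  # "<<SYS>> Be concise. <</SYS>>\n"
```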

pyproject.toml — 3 additions, 1 deletion
```diff
@@ -11,10 +11,13 @@ license = { text = "MIT" }
 authors = [
     { name = "Andrei Betlen", email = "abetlen@gmail.com" },
 ]
+# mkdocs-material requires "jinja2~=3.0"
+# transformers requires "jinja2>=2.11.3"
 dependencies = [
     "typing-extensions>=4.5.0",
     "numpy>=1.20.0",
     "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
 ]
 requires-python = ">=3.8"
 classifiers = [
@@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
 
 [tool.pytest.ini_options]
 addopts = "--ignore=vendor"
-
```

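Of the two constraints noted in the comments, `jinja2>=2.11.3` is deliberately the looser lower bound: it satisfies `transformers` directly, while an environment that also installs mkdocs-material simply resolves to a 3.x release within `jinja2~=3.0`.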
tests/test_llama_chat_format.py — new file, +50 lines
```python
from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )

    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."


# Optionally, include a test for the 'stop' field if it's part of the functionality.
```
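Picking up the trailing comment, a minimal sketch of such a test; it relies only on the fact that `ChatFormatterResponse.stop` defaults to `None` and that `AutoChatFormatter.__call__` never sets it:

```python
def test_llama2_formatter_stop_default(sequence_of_messages):
    # The current formatter only fills in `prompt`, so `stop` should be None.
    formatter_response = Llama2Formatter()(sequence_of_messages)
    assert formatter_response.stop is None
```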
