diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..e75267d --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,58 @@ +name: Tests + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + download-model: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Download model from Kaggle + run: | + mkdir -p model + curl -L -u ${{ secrets.KAGGLE_USERNAME }}:${{ secrets.KAGGLE_KEY }} \ + -o model/model.tar.gz \ + "https://www.kaggle.com/api/v1/models/google/gemma/gemmaCpp/2b-it-mqa/1/download" + shell: bash + - name: Upload model as artifact + uses: actions/upload-artifact@v4 + with: + name: model + path: model/ + + tests: + needs: download-model + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Download model artifact + uses: actions/download-artifact@v4 + with: + name: model + path: model/ + - name: Uncompress model files + run: | + tar -xzf model/model.tar.gz -C model + rm model/model.tar.gz + shell: bash + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] -v + - name: Test with pytest + run: pytest tests/ diff --git a/.gitignore b/.gitignore index 551b2a3..0349f1a 100644 --- a/.gitignore +++ b/.gitignore @@ -162,5 +162,9 @@ cython_debug/ # Vscode .vscode/ -#p Precommit +# Project .pre-commit-config.yaml +models/ +fixed_wheels +playground +db.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 60b0238..2f206bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,27 +8,16 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) -FetchContent_MakeAvailable(sentencepiece) - -FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG 8fb44ed6dd123f63dca95c20c561e8ca1de511d7) +FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG 7122afed5a89c082fac028ab152cc50af3e57386) FetchContent_MakeAvailable(gemma) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) -FetchContent_MakeAvailable(highway) - FetchContent_Declare(pybind11 GIT_REPOSITORY https://github.com/pybind/pybind11.git GIT_TAG v2.10.4) FetchContent_MakeAvailable(pybind11) # Create the Python module -pybind11_add_module(pygemma src/gemma_binding.cpp) +pybind11_add_module(_pygemma src/_pygemma/gemma_binding.cpp) -target_link_libraries(pygemma PRIVATE libgemma hwy hwy_contrib sentencepiece) +target_link_libraries(_pygemma PRIVATE libgemma) -# Link against libgemma.a and any other necessary libraries FetchContent_GetProperties(gemma) -FetchContent_GetProperties(sentencepiece) -target_include_directories(pygemma PRIVATE ${gemma_SOURCE_DIR}) -target_include_directories(pygemma PRIVATE ${sentencepiece_SOURCE_DIR}) -target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) -target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) +target_include_directories(_pygemma PRIVATE ${gemma_SOURCE_DIR}) diff --git a/README.md b/README.md index d5529ff..fddb261 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # gemma-cpp-python: Python Bindings for [gemma.cpp](https://github.com/google/gemma.cpp) -**Latest Version: v0.1.3** +**Latest Version: v0.1.3.post3** +- Fixed absolute path for libsentencepiece.0.0.0.dylib - Interface changes due to updates in gemma.cpp. - Enhanced user experience for ease of use 🙏. Give it a try! @@ -11,10 +12,26 @@ ## 🙏 Acknowledgments Special thanks to the creators and contributors of [gemma.cpp](https://github.com/google/gemma.cpp) for their foundational work. +## 💬 Demo Chat and Chat with Website! +Check out the new chat demo included in the examples directory! This interactive interface showcases how you can engage in real-time conversations with the Gemma model. + +For the Chat with Website, please visit the [tutorial](examples/webchat/README.md) for more detail/ + + +### Using Gemma to chat with website +![Gemma Cpp Python Chat with Website Demo](asset/demo_chat_website.png) + +### Chat with Gemma +![Gemma Cpp Python Chat Demo](asset/demo_chat.png) + ## 🛠 Installation `Prerequisites`: Ensure Python 3.8+ and pip are installed. +`System requirements`: For now, I only tested it on the Unix-like Platforms and the MacOS. Please visit the [gemma.cpp installation](https://github.com/google/gemma.cpp?tab=readme-ov-file#system-requirements) for more details. + +`Models`: pygemma supported 2b-it-sfp model for now, to install model, [please visit here](https://github.com/google/gemma.cpp?tab=readme-ov-file#step-1-obtain-model-weights-and-tokenizer-from-kaggle-or-hugging-face-hub) + ### Install from PyPI For a quick setup, install directly from PyPI: ```bash @@ -49,6 +66,13 @@ gemma.load_model("/path/to/tokenizer", "/path/to/compressed_weight/", "model_typ gemma.completion("Write a poem") ``` +To run the demo on your local machine: +```bash +cd gemma-cpp-python/examples +pip install -r requirements.txt +streamlit run streamlit_demo.py +``` + ## 🤝 Contributing Contributions are welcome. Please clone the repository, push your changes to a new branch, and submit a pull request. diff --git a/asset/demo_chat.png b/asset/demo_chat.png new file mode 100644 index 0000000..64590a8 Binary files /dev/null and b/asset/demo_chat.png differ diff --git a/asset/demo_chat_website.png b/asset/demo_chat_website.png new file mode 100644 index 0000000..d4de4cb Binary files /dev/null and b/asset/demo_chat_website.png differ diff --git a/examples/example.py b/examples/example.py new file mode 100644 index 0000000..40cc361 --- /dev/null +++ b/examples/example.py @@ -0,0 +1,26 @@ +from pygemma import Gemma, ModelType, ModelTraining +from time import time + +TOKENIZER_PATH = "../model/tokenizer.spm" +COMPRESSED_WEIGHTS_PATH = "../model/2b-it-mqa.sbs" +MODEL_TYPE = ModelType.Gemma2B +MODEL_TRAINING = ModelTraining.GEMMA_IT + + +def main(): + gemma = Gemma( + tokenizer_path=TOKENIZER_PATH, + compressed_weights_path=COMPRESSED_WEIGHTS_PATH, + model_type=MODEL_TYPE, + model_training=MODEL_TRAINING, + ) + + start_time = time() + res = gemma("Hello world!") + print(f"Generated: {res}") + + print(f"Elapsed time: {time() - start_time:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/examples/simple-webchat/requirements.txt b/examples/simple-webchat/requirements.txt new file mode 100644 index 0000000..12a4706 --- /dev/null +++ b/examples/simple-webchat/requirements.txt @@ -0,0 +1 @@ +streamlit diff --git a/examples/simple-webchat/streamlit_demo.py b/examples/simple-webchat/streamlit_demo.py new file mode 100644 index 0000000..dfc3b78 --- /dev/null +++ b/examples/simple-webchat/streamlit_demo.py @@ -0,0 +1,91 @@ +import time +import streamlit as st +from pygemma import Gemma + +st.set_page_config(page_title="Gemma 💬") +st.title("Gemma Cpp Python Chat Demo 🎈") + +# Initialize session state for the model and its load status +if "model_loaded" not in st.session_state: + st.session_state["model_loaded"] = False + st.session_state["gemma"] = None + st.session_state["messages"] = [ + {"role": "assistant", "content": "How may I help you?"} + ] + + +@st.cache_resource +def load_gemma_model(tokenizer_path, weights_path, model_type): + gemma = Gemma() + gemma.load_model(tokenizer_path, weights_path, model_type) + return gemma + + +# Sidebar for model configuration +with st.sidebar: + st.title("Gemma Config") + tokenizer_path = st.text_input( + "Tokenizer path", value="", placeholder="tokenizer.spm" + ) + weights_path = st.text_input( + "Compressed weights path", value="", placeholder="2b-it-sfp.sbs" + ) + model_type = st.text_input("Model type", value="2b-it", placeholder="2b-it") + + # Load model button in the sidebar + if st.button("Load Model"): + st.session_state["gemma"] = load_gemma_model( + tokenizer_path, weights_path, model_type + ) + st.session_state["model_loaded"] = True + + # Indicate whether the model is loaded + if st.session_state["model_loaded"]: + st.sidebar.success("Model Loaded Successfully!") + else: + st.sidebar.warning('Model Not Loaded. Click "Load Model" to load the model.') + + st.markdown( + "📖 Check the detail at [gemma-cpp-python](https://github.com/namtranase/gemma-cpp-python)!" + ) + +# Store LLM generated responses +if "messages" not in st.session_state: + st.session_state.messages = [ + {"role": "assistant", "content": "How may I help you?"} + ] + +# Store LLM generated responses +if "messages" not in st.session_state.keys(): + st.session_state.messages = [ + {"role": "assistant", "content": "How may I help you?"} + ] + +# Display chat messages +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.write(message["content"]) + +# Function for generating LLM response +def generate_response(prompt_input): + # Hugging Face Login + if st.session_state.model_loaded: + return st.session_state.gemma.completion(prompt_input) + else: + return "Please load the model first." + + +# User-provided prompt +if prompt := st.chat_input(disabled=not (st.session_state["model_loaded"])): + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.write(prompt) + +# Generate a new response if last message is not from assistant +if st.session_state.messages[-1]["role"] != "assistant": + with st.chat_message("assistant"): + with st.spinner("Thinking..."): + response = generate_response(prompt) + st.write(response) + message = {"role": "assistant", "content": response} + st.session_state.messages.append(message) diff --git a/examples/webchat/README.md b/examples/webchat/README.md new file mode 100644 index 0000000..2db7ef7 --- /dev/null +++ b/examples/webchat/README.md @@ -0,0 +1,24 @@ +## 🌐 Chat with Website Feature + +Gemma branches out into the web with our `Chat with Website` feature. Strike up a conversation with any website and let Gemma extract the essence for a delightful chat experience. + +Special thank to the authors of repo: [scrapeGPT](https://github.com/LexiestLeszek/scrapeGPT), we based on scrapeGPT to build our demo! + +### Quick Start + +1. Launch the Gemma Chat Demo. +2. Plug in the tokenizer and weights paths, set the model type in the sidebar, and hit 'Load Model'. +3. Navigate to 'Website Processing', input a website URL, and press 'Process Website'. +4. Enjoy the interactive session as Gemma digests web content for a smart chat. + +Dive into a seamless blend of AI interaction and web content with just a few clicks! + +```bash +# To get started right away: +cd examples/webchat +pip install -r requirements.txt +streamlit run streamlit_webchat.py +``` + + +![Chat with Website Demo](../../asset/demo_chat_website.png) diff --git a/examples/webchat/__init__.py b/examples/webchat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/webchat/process_web.py b/examples/webchat/process_web.py new file mode 100644 index 0000000..75350f1 --- /dev/null +++ b/examples/webchat/process_web.py @@ -0,0 +1,219 @@ +import json +import time +from datetime import datetime +from io import BytesIO +from urllib.parse import urljoin, urlparse + +import requests +from bs4 import BeautifulSoup +from fp.fp import FreeProxy +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import Chroma +from PyPDF2 import PdfReader + +# Proxy init +def get_proxy(): + print("Starting proxy ...") + proxy_url = FreeProxy( + country_id=[ + "US", + "CA", + "FR", + "NZ", + "SE", + "PT", + "CZ", + "NL", + "ES", + "SK", + "UK", + "PL", + "IT", + "DE", + "AT", + "JP", + ], + https=True, + rand=True, + timeout=3, + ).get() + proxy_obj = {"server": proxy_url, "username": "", "password": ""} + + print(f"Proxy generated: {proxy_url}") + + return proxy_obj + + +def save_to_db(text, url): + timestamp = datetime.now().isoformat() + # Load existing data from db.json + try: + with open("db.json", "r") as f: + data = json.load(f) + except FileNotFoundError: + data = [] + + # Create a new entry with the domain name as key + website = {"date": timestamp, "text": text} + new_entry = {"start_url": url, "data": website} + + # Append new entry to the data list + data.append(new_entry) + + # Write data back to db.json + with open("db.json", "w") as f: + json.dump(data, f, indent=4) + + +def scrape_webpages(urls, proxy): + print("Scraping text from webpages from each of the links ...") + scraped_texts = [] + for url in urls: + try: + if url.endswith(".pdf"): + response = requests.get(url, proxies=proxy) + reader = PdfReader(BytesIO(response.content)) + number_of_pages = len(reader.pages) + + for p in range(number_of_pages): + + page = reader.pages[p] + text = page.extract_text() + scraped_texts.append(text) + else: + page = requests.get(url, proxies=proxy) + soup = BeautifulSoup(page.content, "html.parser") + text = " ".join([p.get_text() for p in soup.find_all("p")]) + scraped_texts.append(text) + + except Exception as e: + print(f"Failed to scrape {url}: {e}") + + all_scraped_text = "\n".join(scraped_texts) + print("Finished scraping the text from webpages!") + return all_scraped_text + + +def get_domain(url): + return urlparse(url).netloc + + +def get_robots_file(url, proxy): + robots_url = urljoin(url, "/robots.txt") + try: + response = requests.get(robots_url, proxies=proxy) + return response.text + except Exception as e: + print(f"Error fetching robots.txt: {e}") + return None + + +def parse_robots(content): + # This function assumes simple rules without wildcards, comments, etc. + # For a full parser, consider using a library like robotparser. + disallowed = [] + for line in content.splitlines(): + if line.startswith("Disallow:"): + path = line[len("Disallow:") :].strip() + disallowed.append(path) + return disallowed + + +def is_allowed(url, disallowed_paths, base_domain): + parsed_url = urlparse(url) + if parsed_url.netloc != base_domain: + return False + for path in disallowed_paths: + if parsed_url.path.startswith(path): + return False + return True + + +def scrape_site_links(url, proxy): + visited_links = set() + not_visited_links = set() + to_visit = [url] + base_domain = get_domain(url) + disallowed_paths = parse_robots(get_robots_file(url, proxy)) + last_found_time = time.time() # Track the last time a link was found + + while to_visit: + # Break the loop if 30 seconds have passed without finding a new link + if time.time() - last_found_time > 15: + print("FINISHED scraping the links") + break + + current_url = to_visit.pop(0) + if current_url not in visited_links and is_allowed( + current_url, disallowed_paths, base_domain + ): + visited_links.add(current_url) + try: + print(f"{current_url}") + response = requests.get(current_url, proxies=proxy) + soup = BeautifulSoup(response.text, "html.parser") + for link in soup.find_all("a", href=True): + new_url = urljoin(current_url, link["href"]) + if new_url not in visited_links: + to_visit.append(new_url) + last_found_time = time.time() # Update the last found time + except Exception as e: + print(f" !!! COULD NOT VISIT: {current_url}") + not_visited_links.add(current_url) + + return visited_links + + +class WebProcesser: + def __init__(self) -> None: + self.chunk_size = (500,) + self.chunk_overlap = (100,) + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=500, chunk_overlap=100 + ) + self.embedding = HuggingFaceEmbeddings( + model_name="sentence-transformers/all-MiniLM-L6-v2" + ) + self.db = None + self.retriever = None + self.db_path = "db.json" + db_file = json.dumps([]) + with open(self.db_path, "w") as outfile: + outfile.write(db_file) + + def init_db_website(self, url): + web_text = "" + try: + with open(self.db_path, "r") as f: + data = json.load(f) + for entry in data: + if ( + url in entry["start_url"] + ): # ADD check for today's scraped website data, not longer + print("Website is already scraped today!") + web_text = entry["data"]["text"] + except FileNotFoundError: + data = [] + # Check if website already in the db + if not web_text: + proxy = get_proxy() + # Scrape all the links from the given start URL using the proxy + all_links = scrape_site_links(url, proxy) + + # Scrape the content from all the links obtained, using the proxy + web_text = scrape_webpages(all_links, proxy) + save_to_db(web_text, url) + + documents = self.text_splitter.split_text(str(web_text)) + self.db = Chroma.from_texts(documents, embedding=self.embedding) + self.retriever = self.db.as_retriever(search_kwargs={"k": 3}) + return True + + def get_context(self, question, chunk_size=500, chunk_overlap=100): + """Get context from question and txt file""" + print("Embedding model started ...") + context = self.retriever.get_relevant_documents(question) + print(f"Emdeggind Model returned: {context}") + + return context diff --git a/examples/webchat/requirements.txt b/examples/webchat/requirements.txt new file mode 100644 index 0000000..3e95c9c --- /dev/null +++ b/examples/webchat/requirements.txt @@ -0,0 +1,8 @@ +aiogram==2.22.1 +beautifulsoup4==4.11.1 +free_proxy==1.1.1 +langchain==0.1.6 +langchain_community==0.0.19 +PyPDF2==3.0.1 +sentence-transformers==2.6.0 +chromadb==0.4.24 diff --git a/examples/webchat/streamlit_webchat.py b/examples/webchat/streamlit_webchat.py new file mode 100644 index 0000000..55b4173 --- /dev/null +++ b/examples/webchat/streamlit_webchat.py @@ -0,0 +1,122 @@ +import time +import streamlit as st +from pygemma import Gemma +from process_web import WebProcesser + + +st.set_page_config(page_title="Chat with Website 💬") +st.title("Gemma Cpp Python Chat with Website Demo 🎈") + +# Initialize session state for the model and its load status +if "model_loaded" not in st.session_state: + st.session_state["model_loaded"] = False + st.session_state["website_loaded"] = False + st.session_state["gemma"] = None + st.session_state["web_processer"] = None + st.session_state["messages"] = [ + {"role": "assistant", "content": "How may I help you?"} + ] + + +@st.cache_resource +def load_gemma_model(tokenizer_path, weights_path, model_type): + gemma = Gemma() + gemma.load_model(tokenizer_path, weights_path, model_type) + return gemma + + +@st.cache_resource +def load_web_processer(): + web_processor = WebProcesser() + return web_processor + + +# Sidebar for model configuration +with st.sidebar: + st.title("Gemma Config") + tokenizer_path = st.text_input( + "Tokenizer path", value="tokenizer.spm", placeholder="tokenizer.spm" + ) + weights_path = st.text_input( + "Compressed weights path", value="2b-it-sfp.sbs", placeholder="2b-it-sfp.sbs" + ) + model_type = st.text_input("Model type", value="2b-it", placeholder="2b-it") + + # Load model button in the sidebar + if st.button("Load Model"): + st.session_state["gemma"] = load_gemma_model( + tokenizer_path, weights_path, model_type + ) + st.session_state["model_loaded"] = True + + # Indicate whether the model is loaded + if st.session_state["model_loaded"]: + st.sidebar.success("Model Loaded Successfully!") + else: + st.sidebar.warning('Model Not Loaded. Click "Load Model" to load the model.') + + st.markdown("## Website Processing") + website_url = st.text_input( + "Website URL", + value="https://namtranase.github.io/terminalmind/", + placeholder="Enter website URL", + ) + + if st.button("Process Website"): + st.session_state["web_processor"] = WebProcesser() + if website_url: + # Placeholder for the function to process website data + st.session_state["web_processor"].init_db_website(website_url) + st.session_state["website_loaded"] = True + st.sidebar.success("Website processed successfully, Now you can ask!") + else: + st.sidebar.error("Please enter a valid website URL.") + + st.markdown( + "📖 Check the detail at [gemma-cpp-python](https://github.com/namtranase/gemma-cpp-python)!" + ) + +# Store LLM generated responses +if "messages" not in st.session_state: + st.session_state.messages = [ + {"role": "assistant", "content": "How may I help you?"} + ] + +# Store LLM generated responses +if "messages" not in st.session_state.keys(): + st.session_state.messages = [ + {"role": "assistant", "content": "How may I help you?"} + ] + +# Display chat messages +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.write(message["content"]) + +# Function for generating LLM response +def generate_response(prompt_input): + if st.session_state.model_loaded and st.session_state.website_loaded: + context = st.session_state["web_processor"].get_context(prompt_input) + prompt = f"""Use the following pieces of context to answer the question at the end. + Context: {context}.\n + Question: {prompt_input} + Helpful Answer:""" + return st.session_state.gemma.completion(prompt) + else: + return "Please load the model first." + + +# User-provided prompt +if prompt := st.chat_input(disabled=not (st.session_state["model_loaded"])): + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.write(prompt) + +# Generate a new response if last message is not from assistant +if st.session_state.messages[-1]["role"] != "assistant": + with st.chat_message("assistant"): + with st.spinner("Thinking..."): + response = generate_response(prompt) + st.write(response) + message = {"role": "assistant", "content": response} + st.session_state.messages.append(message) diff --git a/model/.gitignore b/model/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 2fd725a..7e26269 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,24 @@ [build-system] requires = ["setuptools", "wheel", "cmake"] build-backend = "setuptools.build_meta" + +[project] +name = "pygemma" +version = "0.1.3.post3" +authors = [ + {name = "Nam Tran", email = "namtran.ase@gmail.com"}, +] +description = "Python bindings for the gemma.cpp library" +readme = "README.md" +license = { text = "MIT" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +requires-python = ">=3.8" + +[project.optional-dependencies] +test = [ + "pytest>=8.1.1" +] \ No newline at end of file diff --git a/scripts/release.sh b/scripts/release.sh index 6b0caf0..697bc30 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -1,3 +1,21 @@ +#!/bin/bash + +# Remove previous distribution files rm -rf dist/* + +# Build the distribution python3 setup.py sdist bdist_wheel -delocate-wheel -w fixed_wheels -v dist/*.whl + +# Check the operating system +OS="$(uname)" +if [ "$OS" = "Darwin" ]; then + # macOS specific commands + delocate-wheel -w fixed_wheels -v dist/*.whl + mv fixed_wheels/*.whl dist/ +elif [ "$OS" = "Linux" ]; then + # Linux specific commands (if any can be added here) + echo "Linux OS detected. No additional steps required for Linux." +fi + +# Upload the distribution to PyPI +twine upload dist/* \ No newline at end of file diff --git a/setup.py b/setup.py index 030cffe..baaac66 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,10 @@ import os +import platform import subprocess import sys -from setuptools import setup, find_packages, Extension + +from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -import platform class CMakeExtension(Extension): @@ -27,23 +28,26 @@ def run(self): def build_extension(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + os.makedirs(self.build_temp, exist_ok=True) + cmake_args = [ - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, - "-DPYTHON_EXECUTABLE=" + sys.executable, + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", ] + # Allow gemma.cpp to be built on Windows with ClangCL + # Refer to https://github.com/google/gemma.cpp/pull/6 + if platform.system() == "Windows": + cmake_args += ["-T", "ClangCL"] + cfg = "Debug" if self.debug else "Release" build_args = ["--config", cfg] - # Add a parallel build option - build_args += [ - "--", - "-j", - "12", - ] - - if not os.path.exists(self.build_temp): - os.makedirs(self.build_temp) + if platform.system() == "Windows": + build_args += ["--", "/m:12"] + else: + build_args += ["--", "-j12"] subprocess.check_call( ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp @@ -55,23 +59,6 @@ def build_extension(self, ext): setup( - name="pygemma", - version="0.1.3", - author="Nam Tran", - author_email="namtran.ase@gmail.com", - description="A Python package with a C++ backend using gemma.cpp", - long_description=""" - This package provides Python bindings to a C++ library using pybind11. - """, - long_description_content_type="text/markdown", - ext_modules=[CMakeExtension("pygemma")], + ext_modules=[CMakeExtension("_pygemma")], cmdclass=dict(build_ext=CMakeBuild), - zip_safe=False, - packages=find_packages(), - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - python_requires=">=3.8", ) diff --git a/src/_pygemma/__init__.pyi b/src/_pygemma/__init__.pyi new file mode 100644 index 0000000..56d2bf3 --- /dev/null +++ b/src/_pygemma/__init__.pyi @@ -0,0 +1,35 @@ +from typing import List + +class GemmaModel: + def __init__( + self, + tokenizer_path: str, + compressed_weights_path: str, + model_type: int, + model_training: int, + n_threads: int, + ) -> None: + pass + + @property + def bos_token(self) -> int: ... + @property + def eos_token(self) -> int: ... + def generate( + self, + prompt: str, + max_tokens: int, + max_generated_tokens: int, + temperature: float, + seed: int, + verbosity: int, + ) -> str: ... + def tokenize( + self, + text: str, + add_bos: bool = True, + ) -> List[int]: ... + def detokenize( + self, + tokens: List[int], + ) -> str: ... diff --git a/src/_pygemma/gemma_binding.cpp b/src/_pygemma/gemma_binding.cpp new file mode 100644 index 0000000..beb8232 --- /dev/null +++ b/src/_pygemma/gemma_binding.cpp @@ -0,0 +1,142 @@ +#include "gemma_binding.h" + +#include +#include + +#include + +#include "gemma.h" // Gemma +#include "util/app.h" +#include "util/args.h" + +namespace py = pybind11; + +GemmaModel::GemmaModel(const char *tokenizer_path_str, + const char *compressed_weights_path_str, + int model_type_id, int training_id, int num_threads) { + const gcpp::Path tokenizer_path = gcpp::Path{tokenizer_path_str}; + const gcpp::Path compressed_weights_path = + gcpp::Path{compressed_weights_path_str}; + const gcpp::Path weights_path = gcpp::Path{""}; + + this->model_type = static_cast(model_type_id); + this->model_training = static_cast(training_id); + this->num_threads = static_cast(num_threads); + + pool = new hwy::ThreadPool(num_threads); + + kv_cache = CreateKVCache(model_type); + + // For many-core, pinning threads to cores helps. + if (this->num_threads > 10) { + gcpp::PinThreadToCore(this->num_threads - 1); // Main thread + + pool->Run(0, pool->NumThreads(), [](uint64_t /*task*/, size_t thread) { + gcpp::PinThreadToCore(thread); + }); + } + + model = new gcpp::Gemma(tokenizer_path, compressed_weights_path, weights_path, + model_type, model_training, *pool); +} + +GemmaModel::~GemmaModel() { delete model; } + +int GemmaModel::get_bos_token() const { return bos_token; } + +int GemmaModel::get_eos_token() const { return eos_token; } + +std::vector GemmaModel::tokenize(const std::string &text, + const bool add_bos) { + std::vector tokens; + + if (!model->Tokenizer()->Encode(text, &tokens).ok()) { + throw std::runtime_error("Tokenization failed"); + } + + if (add_bos) { + tokens.insert(tokens.begin(), bos_token); + } + + return tokens; +} + +std::string GemmaModel::detokenize(const std::vector &tokens) { + std::string text; + if (!model->Tokenizer()->Decode(tokens, &text).ok()) { + throw std::runtime_error("Detokenization failed"); + } + return text; +} + +std::string GemmaModel::generate(const std::string &prompt_string, + size_t max_tokens, size_t max_generated_tokens, + float temperature, uint_fast32_t seed, + int verbosity) { + size_t pos = 0; // KV Cache position + + // Initialize random number generator + std::mt19937 gen; + gen.seed(seed); + + const std::string formatted = [&]() { + if (model_training == gcpp::ModelTraining::GEMMA_IT) { + return "user\n" + prompt_string + + "\nmodel\n"; + } + return prompt_string; + }(); + + std::vector tokens = tokenize(formatted, true); + size_t ntokens = tokens.size(); + + std::string completion; + + // This callback function gets invoked everytime a token is generated + const gcpp::StreamFunc stream_token = [this, &pos, &gen, &ntokens, + tokenizer = model->Tokenizer(), + &completion](int token, float) { + ++pos; + if (pos < ntokens) { + // print feedback + } else if (token != this->eos_token) { + std::string token_text; + HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); + completion += token_text; + } + return true; + }; + + gcpp::GenerateGemma(*model, + {.max_tokens = max_tokens, + .max_generated_tokens = max_generated_tokens, + .temperature = temperature, + .verbosity = verbosity}, + tokens, /*KV cache position = */ 0, kv_cache, *pool, + stream_token, gen); + + return completion; +} + +PYBIND11_MODULE(_pygemma, m) { + m.doc() = "Python binding for gemma.cpp"; + + py::class_(m, "GemmaModel") + .def_property_readonly("bos_token", &GemmaModel::get_bos_token, + "Get the BOS token") + .def_property_readonly("eos_token", &GemmaModel::get_eos_token, + "Get the EOS token") + .def(py::init(), + py::arg("tokenizer_path"), py::arg("compressed_weights_path"), + py::arg("model_type"), py::arg("model_training"), + py::arg("num_threads"), "Initialize the Gemma model") + .def("tokenize", &GemmaModel::tokenize, py::arg("text"), + py::arg("add_bos"), + "Tokenize the input text and return the tokenized text") + .def("detokenize", &GemmaModel::detokenize, py::arg("tokens"), + "Detokenize the input tokens and return the detokenized text") + .def("generate", &GemmaModel::generate, py::arg("prompt"), + py::arg("max_tokens"), py::arg("max_generated_tokens"), + py::arg("temperature"), py::arg("seed"), py::arg("verbosity"), + "Generate text based on the input prompt"); +} \ No newline at end of file diff --git a/src/_pygemma/gemma_binding.h b/src/_pygemma/gemma_binding.h new file mode 100644 index 0000000..db8402f --- /dev/null +++ b/src/_pygemma/gemma_binding.h @@ -0,0 +1,35 @@ +#include "gemma.h" // Gemma + +#pragma once +class GemmaModel { + private: + gcpp::Gemma *model; + gcpp::Model model_type; + gcpp::ModelTraining model_training; + + size_t num_threads; + hwy::ThreadPool *pool; + gcpp::KVCache kv_cache; + + const int eos_token = 1; + const int bos_token = 2; + + public: + GemmaModel(const char *tokenizer_path_str, + const char *compressed_weights_path_str, int model_type_id, + int training_id, int num_threads); + + ~GemmaModel(); + + int get_bos_token() const; + + int get_eos_token() const; + + std::vector tokenize(const std::string &text, const bool add_bos = true); + + std::string detokenize(const std::vector &tokens); + + std::string generate(const std::string &prompt, size_t max_tokens, + size_t max_generated_tokens, float temperature, + uint_fast32_t seed, int verbosity); +}; \ No newline at end of file diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp deleted file mode 100644 index 5e177c5..0000000 --- a/src/gemma_binding.cpp +++ /dev/null @@ -1,324 +0,0 @@ -#include -#include - -#include "gemma_binding.h" -namespace py = pybind11; - -static constexpr std::string_view kAsciiArtBanner = - " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" - " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" - "| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n" - " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n" - " __/ | | | | |\n" - " |___/ |_| |_|"; - -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { - loader.Print(app.verbosity); - inference.Print(app.verbosity); - app.Print(app.verbosity); - - if (app.verbosity >= 2) { - time_t now = time(nullptr); - char* dt = ctime(&now); // NOLINT - std::cout << "Date & Time : " << dt - << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize - << "\n" - << "Hardware concurrency : " - << std::thread::hardware_concurrency() << std::endl - << "Instruction set : " - << hwy::TargetName(hwy::DispatchedTarget()) << " (" - << hwy::VectorBytes() * 8 << " bits)" << "\n" - << "Compiled config : " << CompiledConfig() << "\n" - << "Weight Type : " - << gcpp::TypeName(gcpp::WeightT()) << "\n" - << "EmbedderInput Type : " - << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; - } -} - -void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, - gcpp::AppArgs& app) { - std::cerr - << kAsciiArtBanner - << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" - "==========================================================\n\n" - "To run gemma.cpp, you need to " - "specify 3 required model loading arguments:\n --tokenizer\n " - "--compressed_weights\n" - " --model.\n"; - std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " - "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n*Application Arguments*\n\n"; - app.Help(); - std::cerr << "\n"; -} - -void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const InferenceArgs& args, int verbosity, - const gcpp::AcceptFunc& accept_token, std::string& eot_line) { - PROFILER_ZONE("Gen.misc"); - int abs_pos = 0; // absolute token index over all turns - int current_pos = 0; // token index within the current turn - int prompt_size{}; - - std::mt19937 gen; - if (args.deterministic) { - gen.seed(42); - } else { - std::random_device rd; - gen.seed(rd()); - } - - // callback function invoked for each generated token. - auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, - tokenizer = model.Tokenizer(), - verbosity](int token, float) { - ++abs_pos; - ++current_pos; - if (current_pos < prompt_size) { - std::cerr << "." << std::flush; - } else if (token == gcpp::EOS_ID) { - if (!args.multiturn) { - abs_pos = 0; - if (args.deterministic) { - gen.seed(42); - } - } - if (verbosity >= 2) { - std::cout << "\n[ End ]\n"; - } - } else { - std::string token_text; - HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); - // +1 since position is incremented above - if (current_pos == prompt_size + 1) { - // first token of response - token_text.erase(0, token_text.find_first_not_of(" \t\n")); - if (verbosity >= 1) { - std::cout << std::endl << std::endl; - } - } - std::cout << token_text << std::flush; - } - return true; - }; - - while (abs_pos < args.max_tokens) { - std::string prompt_string; - std::vector prompt; - current_pos = 0; - { - PROFILER_ZONE("Gen.input"); - if (verbosity >= 1) { - std::cout << "> " << std::flush; - } - - if (eot_line.size() == 0) { - std::getline(std::cin, prompt_string); - } else { - std::string line; - while (std::getline(std::cin, line)) { - if (line == eot_line) { - break; - } - prompt_string += line + "\n"; - } - } - } - - if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { - return; - } - - if (prompt_string == "%c" || prompt_string == "%C") { - abs_pos = 0; - continue; - } - - if (model.model_training == ModelTraining::GEMMA_IT) { - // For instruction-tuned models: add control tokens. - prompt_string = "user\n" + prompt_string + - "\nmodel\n"; - if (abs_pos > 0) { - // Prepend "" token if this is a multi-turn dialogue - // continuation. - prompt_string = "\n" + prompt_string; - } - } - - HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); - - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. - if (abs_pos == 0) { - prompt.insert(prompt.begin(), 2); - } - - prompt_size = prompt.size(); - - std::cerr << std::endl << "[ Reading prompt ] " << std::flush; - - const double time_start = hwy::platform::Now(); - GenerateGemma(model, args.max_tokens, args.max_generated_tokens, - args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, - stream_token, accept_token, gen, verbosity); - const double time_end = hwy::platform::Now(); - const double tok_sec = current_pos / (time_end - time_start); - if (verbosity >= 2) { - std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" - << std::endl - << tok_sec << " tokens / sec" << std::endl; - } - std::cout << std::endl << std::endl; - } - std::cout - << "max_tokens (" << args.max_tokens - << ") exceeded. Use a larger value if desired using the --max_tokens " - << "command line flag.\n"; -} - - -std::vector tokenize( - const std::string& prompt_string, - const sentencepiece::SentencePieceProcessor* tokenizer) { - std::string formatted = "user\n" + prompt_string + - "\nmodel\n"; - std::vector tokens; - HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok()); - tokens.insert(tokens.begin(), 2); // BOS token - return tokens; -} - -int GemmaWrapper::completionPrompt(std::string& prompt) { - size_t pos = 0; // KV Cache position - size_t num_threads = static_cast(std::clamp( - static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); - hwy::ThreadPool pool(num_threads); - // Initialize random number generator - std::mt19937 gen; - std::random_device rd; - gen.seed(rd()); - - // Tokenize instruction - std::vector tokens = - tokenize(prompt, this->m_model->Tokenizer()); - size_t ntokens = tokens.size(); - - // This callback function gets invoked everytime a token is generated - auto stream_token = [&pos, &gen, &ntokens, tokenizer = this->m_model->Tokenizer()]( - int token, float) { - ++pos; - if (pos < ntokens) { - // print feedback - } else if (token != gcpp::EOS_ID) { - std::string token_text; - HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); - std::cout << token_text << std::flush; - } - return true; - }; - - GenerateGemma(*this->m_model, - {.max_tokens = 2048, - .max_generated_tokens = 1024, - .temperature = 1.0, - .verbosity = 0}, - tokens, /*KV cache position = */ 0, this->m_kvcache, pool, - stream_token, gen); - std::cout << std::endl; -} - -void GemmaWrapper::loadModel(const std::vector &args) { - int argc = args.size() + 1; // +1 for the program name - std::vector argv_vec; - argv_vec.reserve(argc); - argv_vec.push_back(const_cast("pygemma")); - for (const auto &arg : args) - { - argv_vec.push_back(const_cast(arg.c_str())); - } - - char **argv = argv_vec.data(); - - this->m_loader = gcpp::LoaderArgs(argc, argv); - this->m_inference = gcpp::InferenceArgs(argc, argv); - this->m_app = gcpp::AppArgs(argc, argv); - - PROFILER_ZONE("Run.misc"); - - hwy::ThreadPool inner_pool(0); - hwy::ThreadPool pool(this->m_app.num_threads); - // For many-core, pinning threads to cores helps. - if (this->m_app.num_threads > 10) { - PinThreadToCore(this->m_app.num_threads - 1); // Main thread - - pool.Run(0, pool.NumThreads(), - [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); - } - - if (!this->m_model) { - this->m_model.reset(new gcpp::Gemma(this->m_loader.tokenizer, this->m_loader.compressed_weights, this->m_loader.ModelType(), pool)); - } -// auto kvcache = CreateKVCache(loader.ModelType()); - this->m_kvcache = CreateKVCache(this->m_loader.ModelType()); - - if (const char* error = this->m_inference.Validate()) { - ShowHelp(this->m_loader, this->m_inference, this->m_app); - HWY_ABORT("\nInvalid args: %s", error); - } - - if (this->m_app.verbosity >= 1) { - const std::string instructions = - "*Usage*\n" - " Enter an instruction and press enter (%C resets conversation, " - "%Q quits).\n" + - (this->m_inference.multiturn == 0 - ? std::string(" Since multiturn is set to 0, conversation will " - "automatically reset every turn.\n\n") - : "\n") + - "*Examples*\n" - " - Write an email to grandma thanking her for the cookies.\n" - " - What are some historical attractions to visit around " - "Massachusetts?\n" - " - Compute the nth fibonacci number in javascript.\n" - " - Write a standup comedy bit about GPU programming.\n"; - - std::cout << "\033[2J\033[1;1H" // clear screen - << kAsciiArtBanner << "\n\n"; - ShowConfig(this->m_loader, this->m_inference, this->m_app); - std::cout << "\n" << instructions << "\n"; - } -} - -void GemmaWrapper::showConfig() { - ShowConfig(this->m_loader,this->m_inference, this->m_app); -} - -void GemmaWrapper::showHelp() { - ShowHelp(this->m_loader,this->m_inference, this->m_app); -} - - -PYBIND11_MODULE(pygemma, m) { - py::class_(m, "Gemma") - .def(py::init<>()) - .def("show_config", &GemmaWrapper::showConfig) - .def("show_help", &GemmaWrapper::showHelp) - .def("load_model", [](GemmaWrapper &self, - const std::string &tokenizer, - const std::string &compressed_weights, - const std::string &model) { - std::vector args = { - "--tokenizer", tokenizer, - "--compressed_weights", compressed_weights, - "--model", model - }; - self.loadModel(args); // Assuming GemmaWrapper::loadModel accepts std::vector - }, py::arg("tokenizer"), py::arg("compressed_weights"), py::arg("model")) - .def("completion", &GemmaWrapper::completionPrompt, "Function that completes given prompt."); -} diff --git a/src/gemma_binding.h b/src/gemma_binding.h deleted file mode 100644 index c02b70b..0000000 --- a/src/gemma_binding.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once -// Command line text interface to gemma. - -#include -#include -#include -#include -#include // NOLINT -#include - -// copybara:import_next_line:gemma_cpp -#include "compression/compress.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // HasHelp -// copybara:end -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/per_target.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" - -using namespace gcpp; - -class GemmaWrapper { - public: - // GemmaWrapper(); - void loadModel(const std::vector &args); // Consider exception safety - void showConfig(); - void showHelp(); - int completionPrompt(std::string& prompt); - - private: - gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr); - gcpp::InferenceArgs m_inference = gcpp::InferenceArgs(0, nullptr); - gcpp::AppArgs m_app = gcpp::AppArgs(0, nullptr); - std::unique_ptr m_model; - KVCache m_kvcache; -}; diff --git a/src/pygemma/__init__.py b/src/pygemma/__init__.py new file mode 100644 index 0000000..20e242a --- /dev/null +++ b/src/pygemma/__init__.py @@ -0,0 +1 @@ +from .gemma import * diff --git a/src/pygemma/gemma.py b/src/pygemma/gemma.py new file mode 100644 index 0000000..6c29d12 --- /dev/null +++ b/src/pygemma/gemma.py @@ -0,0 +1,101 @@ +import multiprocessing +import os +from enum import Enum +import random +from typing import List, Optional + +import _pygemma + + +class ModelTraining(Enum): + GEMMA_IT = 0 + GEMMA_PT = 1 + + +class ModelType(Enum): + Gemma2B = 0 + Gemma7B = 1 + + +class Gemma: + def __init__( + self, + *, + tokenizer_path: str, + compressed_weights_path: str, + model_type: ModelType, + model_training: ModelTraining, + n_threads: Optional[int] = None, + ): + self.tokenizer_path = tokenizer_path + self.compressed_weights_path = compressed_weights_path + self.model_type = model_type + self.model_training = model_training + + self.n_threads = n_threads or max(multiprocessing.cpu_count() - 2, 1) + + if not os.path.exists(self.tokenizer_path): + raise FileNotFoundError(f"Tokenizer not found: {self.tokenizer_path}") + + if not os.path.exists(self.compressed_weights_path): + raise FileNotFoundError( + f"Compressed weights not found: {self.compressed_weights_path}" + ) + + self.model = _pygemma.GemmaModel( + self.tokenizer_path, + self.compressed_weights_path, + self.model_type.value, + self.model_training.value, + self.n_threads, + ) + + assert self.model + + @property + def bos_token(self) -> int: + assert self.model + return self.model.bos_token + + @property + def eos_token(self) -> int: + assert self.model + return self.model.eos_token + + def __call__( + self, + prompt: str, + *, + max_tokens: int = 2048, + max_generated_tokens: int = 1024, + temperature: float = 1.0, + seed: Optional[int] = None, + verbosity: int = 0, + ) -> str: + assert self.model + + seed = seed or random.randint(0, 2**32 - 1) + + return self.model.generate( + prompt, + max_tokens, + max_generated_tokens, + temperature, + seed, + verbosity, + ) + + def tokenize( + self, + text: str, + add_bos: bool = True, + ) -> List[int]: + assert self.model + return self.model.tokenize(text, add_bos) + + def detokenize( + self, + tokens: List[int], + ) -> str: + assert self.model + return self.model.detokenize(tokens) diff --git a/tests/test_chat.py b/tests/test_chat.py deleted file mode 100644 index 4b5efc8..0000000 --- a/tests/test_chat.py +++ /dev/null @@ -1,41 +0,0 @@ -import argparse -from pygemma import Gemma - - -def main(): - parser = argparse.ArgumentParser( - description="Test script for pygemma Python bindings." - ) - parser.add_argument( - "--tokenizer", type=str, required=True, help="Path to the tokenizer file." - ) - parser.add_argument( - "--compressed_weights", - type=str, - required=True, - help="Path to the compressed weights file.", - ) - parser.add_argument( - "--model", type=str, required=True, help="Model type identifier." - ) - parser.add_argument( - "--input", - type=str, - required=False, - help="Input text to chat with the model. If None, Switch to Chat mode.", - default="Hello.", - ) - # Now using the parsed arguments - args = parser.parse_args() - - gemma = Gemma() - gemma.show_config() - gemma.show_help() - gemma.load_model(args.tokenizer, args.compressed_weights, args.model) - gemma.completion("Write a poem") - gemma.completion("What is the best war in history") - - -if __name__ == "__main__": - main() - # python tests/test_chat.py --tokenizer ../Model_Weight/tokenizer.spm --compressed_weights ../Model_Weight/2b-it-sfp.sbs --model 2b-it diff --git a/tests/test_gemma.py b/tests/test_gemma.py new file mode 100644 index 0000000..286d138 --- /dev/null +++ b/tests/test_gemma.py @@ -0,0 +1,35 @@ +import os +from pygemma import Gemma, ModelType, ModelTraining + +# Get the directory that this file is in +dir_path = os.path.dirname(os.path.realpath(__file__)) + +TOKENIZER_PATH = os.path.join(dir_path, "../model/tokenizer.spm") +COMPRESSED_WEIGHTS_PATH = os.path.join(dir_path, "../model/2b-it-mqa.sbs") +MODEL_TYPE = ModelType.Gemma2B +MODEL_TRAINING = ModelTraining.GEMMA_IT + + +def test_gemma(): + gemma = Gemma( + tokenizer_path=TOKENIZER_PATH, + compressed_weights_path=COMPRESSED_WEIGHTS_PATH, + model_type=MODEL_TYPE, + model_training=MODEL_TRAINING, + ) + + assert gemma + assert gemma.model + + text = "Hello world!" + + tokens = gemma.tokenize(text) + assert tokens[0] == gemma.bos_token + assert tokens == [2, 4521, 2134, 235341] + detokenized = gemma.detokenize(tokens) + assert detokenized == text + + # without BOS + tokens_without_bos = gemma.tokenize(text, add_bos=False) + assert tokens_without_bos[0] != gemma.bos_token + assert tokens_without_bos == [4521, 2134, 235341]