diff --git a/.gitmodules b/.gitmodules index 8470285..40f3449 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ [submodule "vendor/gemma.cpp"] path = vendor/gemma.cpp url = https://github.com/google/gemma.cpp -[submodule "vendor/pybind11"] - path = vendor/pybind11 - url = https://github.com/pybind/pybind11.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 1494591..2d7026e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,8 +21,7 @@ FetchContent_Declare(pybind11 GIT_REPOSITORY https://github.com/pybind/pybind11. FetchContent_MakeAvailable(pybind11) # Create the Python module -pybind11_add_module(pygemma src/gemma_binding.cpp) - +add_library(pygemma SHARED src/gemma_binding.cpp) target_link_libraries(pygemma PRIVATE libgemma hwy hwy_contrib sentencepiece) # Link against libgemma.a and any other necessary libraries diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 59a8043..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -pybind11 -pre-commit diff --git a/setup.py b/setup.py index a69d9f9..21fead5 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ def build_extension(self, ext): build_args += [ "--", "-j", - "12", + "6", ] # Specifies the number of jobs to run simultaneously if not os.path.exists(self.build_temp): diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp index 53607db..25b382d 100644 --- a/src/gemma_binding.cpp +++ b/src/gemma_binding.cpp @@ -1,5 +1,3 @@ -#include -#include // #include "gemma.h" // Adjust include path as necessary #include #include @@ -19,7 +17,6 @@ #include "hwy/profiler.h" #include "hwy/timer.h" -namespace py = pybind11; namespace gcpp { @@ -387,10 +384,8 @@ std::string chat_base_wrapper(const std::vector &args) } -PYBIND11_MODULE(pygemma, m) +int main(int argc, char **argv) { - m.doc() = "Pybind11 integration for chat_base function"; - m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments"); - m.def("show_help", &show_help_wrapper, "A wrapper for show_help function"); - m.def("completion", &completion_base_wrapper, "A wrapper for inference function"); + chat_base(argc, argv); + return 0; } diff --git a/vendor/pybind11 b/vendor/pybind11 deleted file mode 160000 index 8b48ff8..0000000 --- a/vendor/pybind11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8b48ff878c168b51fe5ef7b8c728815b9e1a9857