From ef091dce8f9167a520299fa4fffc4323870a497a Mon Sep 17 00:00:00 2001 From: thoughtp0lice Date: Wed, 22 May 2024 20:54:00 -0500 Subject: [PATCH] improve Llama.eval efficiency --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 043fb2a6e..6dad650c6 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -562,12 +562,12 @@ def eval(self, tokens: Sequence[int]): if self.context_params.logits_all: rows = n_tokens cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] + logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, )) self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits else: rows = 1 cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] + logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, )) self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits # Update n_tokens self.n_tokens += n_tokens