From 7abb8671f8c18aecb4f02ff4a60b19106d91725d Mon Sep 17 00:00:00 2001 From: Anik Bhattacharjee Date: Thu, 4 Apr 2024 00:38:48 -0400 Subject: [PATCH] Knowledge doc chunking follow up Follow up to #772 Signed-off-by: Anik Bhattacharjee --- cli/generator/generate_data.py | 25 +++++++--------------- cli/lab.py | 9 ++++++++ cli/utils.py | 38 ++++++++++++++++++---------------- tests/test_lab_generate.py | 2 ++ 4 files changed, 39 insertions(+), 35 deletions(-) diff --git a/cli/generator/generate_data.py b/cli/generator/generate_data.py index ba2e569565..3b0008bb15 100755 --- a/cli/generator/generate_data.py +++ b/cli/generator/generate_data.py @@ -20,7 +20,6 @@ import yaml # Local -from .. import config from ..utils import chunk_document, get_documents, get_taxonomy_diff from . import utils @@ -339,6 +338,7 @@ def generate_data( console_output=True, api_key: Optional[str] = None, chunk_word_count=None, + server_ctx_size=None, ): seed_instruction_data = [] generate_start = time.time() @@ -364,20 +364,17 @@ def generate_data( def unescape(s): return bytes(s, "utf-8").decode("utf-8") - placeholder = seed_instruction_data[0]["document"] - if placeholder: - documents = chunk_document( - documents=placeholder, - max_context_size=config.MAX_CONTEXT_SIZE, - chunk_word_count=chunk_word_count, - ) - test_data = [] for seed_example in seed_instruction_data: user = seed_example["instruction"] - if placeholder: - seed_example["document"] = documents + documents = seed_example["document"] + if documents: + seed_example["document"] = chunk_document( + documents=documents, + server_ctx_size=server_ctx_size, + chunk_word_count=chunk_word_count, + ) if len(seed_example["input"]) > 0: user += "\n" + seed_example["input"] @@ -597,12 +594,6 @@ def read_taxonomy_file(logger, file_path, yaml_rules: Optional[str] = None): documents = get_documents(documents) logger.info("Content from git repo fetched") - # cfg = config.get_default_config() - # documents = chunk_document( - # documents=documents, - # max_context_size=cfg.serve.max_ctx_size, - # chunk_word_count=chunk_word_count, - # ) for t in get_seed_examples(contents): q = t["question"] a = t["answer"] diff --git a/cli/lab.py b/cli/lab.py index 4697c6347c..d1b37a31bf 100755 --- a/cli/lab.py +++ b/cli/lab.py @@ -388,6 +388,13 @@ def serve(ctx, model_path, gpu_layers, num_threads, max_ctx_size): show_default=True, help="Rules file for YAML linting", ) +@click.option( + "--server-ctx-size", + type=click.INT, + default=config.MAX_CONTEXT_SIZE, + show_default=True, + help="The context size is the maximum number of tokens the server will consider.", +) @click.pass_context def generate( ctx, @@ -403,6 +410,7 @@ def generate( api_key, yaml_rules, chunk_word_count, + server_ctx_size, ): """Generates synthetic data to enhance your example data""" # pylint: disable=C0415 @@ -450,6 +458,7 @@ def generate( console_output=not quiet, yaml_rules=yaml_rules, chunk_word_count=chunk_word_count, + server_ctx_size=server_ctx_size, ) except GenerateException as exc: click.secho( diff --git a/cli/utils.py b/cli/utils.py index 1df2a97a8c..6decc3c45a 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -216,33 +216,35 @@ def get_documents(input_pattern: Dict[str, Union[str, List[str]]]) -> List[str]: shutil.rmtree(temp_dir) -def chunk_document(documents: List, max_context_size, chunk_word_count) -> List[str]: +def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]: """ Iterates over the documents and splits them into chunks based on the word count provided by the user. Args: - document (dict): List of documents retrieved from git (can also consist of a single document) - max_context_size (int): Defaults to 4096 + documents (dict): List of documents retrieved from git (can also consist of a single document). + server_ctx_size (int): Context window size of server. chunk_word_count (int): Maximum number of words to chunk a document. Returns: List[str]: List of chunked documents. """ - token_size = int(chunk_word_count * 1.3) # 1 word ~ 1.3 token - content = [] - if token_size < int(max_context_size - 1024): - text_splitter = RecursiveCharacterTextSplitter( - separators=["\n\n", "\n"], - chunk_size=int(token_size * 4), # 1 token ~ 4 English character - chunk_overlap=100, - ) - - for docs in documents: - temp = text_splitter.create_documents([docs]) - content.extend([item.page_content for item in temp]) - - else: + no_tokens_per_doc = int(chunk_word_count * 1.3) # 1 word ~ 1.3 token + if no_tokens_per_doc > int(server_ctx_size - 1024): logger.error( - "Error: Given word count exceeds the required chunk limit i.e. 2400" + "Error: {}".format( + str( + f"Given word count per doc will exceed the server context window size {server_ctx_size}" + ) + ) ) sys.exit() + content = [] + text_splitter = RecursiveCharacterTextSplitter( + separators=["\n\n", "\n"], + chunk_size=int(no_tokens_per_doc * 4), # 1 token ~ 4 English character + chunk_overlap=100, + ) + + for docs in documents: + temp = text_splitter.create_documents([docs]) + content.extend([item.page_content for item in temp]) return content diff --git a/tests/test_lab_generate.py b/tests/test_lab_generate.py index fa938756c6..3cd84b005e 100644 --- a/tests/test_lab_generate.py +++ b/tests/test_lab_generate.py @@ -179,6 +179,7 @@ def test_OpenAI_server_error(self, get_instructions_from_model): rouge_threshold=0.9, console_output=True, chunk_word_count=1000, + server_ctx_size=4096, ) self.assertIn( "There was a problem connecting to the OpenAI server", @@ -210,6 +211,7 @@ def test_no_error(self, get_instructions_from_model): rouge_threshold=0.9, console_output=True, chunk_word_count=1000, + server_ctx_size=4096, ) get_instructions_from_model.assert_called_once() expected_files = [