From 3289e8203de3a26119299594311aeb326224d4e5 Mon Sep 17 00:00:00 2001 From: aajha Date: Tue, 2 Apr 2024 00:12:26 +0530 Subject: [PATCH 1/2] add langchain for chunking Signed-off-by: aajha --- cli/config.py | 2 ++ cli/generator/generate_data.py | 25 ++++++++++++++++++++++++- cli/lab.py | 9 +++++++++ cli/utils.py | 34 ++++++++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_config.py | 2 ++ tests/test_lab_generate.py | 2 ++ 7 files changed, 74 insertions(+), 1 deletion(-) diff --git a/cli/config.py b/cli/config.py index bb1aaabd6e..9c25bff898 100644 --- a/cli/config.py +++ b/cli/config.py @@ -16,6 +16,7 @@ DEFAULT_YAML_RULES = "yaml_rules.yaml" # TODO: these constants should be removed, they should not leak out DEFAULT_NUM_CPUS = 10 +DEFAULT_CHUNK_WORD_COUNT = 1000 DEFAULT_NUM_INSTRUCTIONS = 100 DEFAULT_PROMPT_FILE = "prompt.txt" DEFAULT_GENERATED_FILES_OUTPUT_DIR = "generated" @@ -67,6 +68,7 @@ class _generate(BaseModel): # optional fields num_cpus: Optional[PositiveInt] = DEFAULT_NUM_CPUS + chunk_word_count: Optional[PositiveInt] = DEFAULT_CHUNK_WORD_COUNT num_instructions: Optional[PositiveInt] = DEFAULT_NUM_INSTRUCTIONS output_dir: Optional[StrictStr] = DEFAULT_GENERATED_FILES_OUTPUT_DIR prompt_file: Optional[StrictStr] = DEFAULT_PROMPT_FILE diff --git a/cli/generator/generate_data.py b/cli/generator/generate_data.py index 8868a6f044..8358817318 100755 --- a/cli/generator/generate_data.py +++ b/cli/generator/generate_data.py @@ -20,7 +20,8 @@ import yaml # Local -from ..utils import get_documents, get_taxonomy_diff +from .. import config +from ..utils import chunk_document, get_documents, get_taxonomy_diff from . 
import utils DEFAULT_PROMPT_TEMPLATE = """\ @@ -337,6 +338,7 @@ def generate_data( rouge_threshold: Optional[float] = None, console_output=True, api_key: Optional[str] = None, + chunk_word_count=None, ): seed_instruction_data = [] generate_start = time.time() @@ -362,9 +364,23 @@ def generate_data( def unescape(s): return bytes(s, "utf-8").decode("utf-8") + cfg = config.get_default_config() + + placeholder = seed_instruction_data[0]["document"] + if placeholder: + documents = chunk_document( + documents=placeholder, + max_context_size=cfg.serve.max_ctx_size, + chunk_word_count=chunk_word_count, + ) + test_data = [] for seed_example in seed_instruction_data: user = seed_example["instruction"] + + if placeholder: + seed_example["document"] = documents + if len(seed_example["input"]) > 0: user += "\n" + seed_example["input"] try: @@ -582,6 +598,13 @@ def read_taxonomy_file(logger, file_path, yaml_rules: Optional[str] = None): if documents: documents = get_documents(documents) logger.info("Content from git repo fetched") + + # cfg = config.get_default_config() + # documents = chunk_document( + # documents=documents, + # max_context_size=cfg.serve.max_ctx_size, + # chunk_word_count=chunk_word_count, + # ) for t in get_seed_examples(contents): q = t["question"] a = t["answer"] diff --git a/cli/lab.py b/cli/lab.py index 795c642544..4697c6347c 100755 --- a/cli/lab.py +++ b/cli/lab.py @@ -325,6 +325,13 @@ def serve(ctx, model_path, gpu_layers, num_threads, max_ctx_size): default=config.DEFAULT_NUM_CPUS, show_default=True, ) +@click.option( + "--chunk-word-count", + type=click.INT, + help="Number of words to chunk the document", + default=config.DEFAULT_CHUNK_WORD_COUNT, + show_default=True, +) @click.option( "--num-instructions", type=click.INT, @@ -395,6 +402,7 @@ def generate( endpoint_url, api_key, yaml_rules, + chunk_word_count, ): """Generates synthetic data to enhance your example data""" # pylint: disable=C0415 @@ -441,6 +449,7 @@ def generate( 
rouge_threshold=rouge_threshold, console_output=not quiet, yaml_rules=yaml_rules, + chunk_word_count=chunk_word_count, ) except GenerateException as exc: click.secho( diff --git a/cli/utils.py b/cli/utils.py index c071995aec..1beb0d870b 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -9,8 +9,10 @@ import re import shutil import subprocess +import sys # Third Party +from langchain_text_splitters import RecursiveCharacterTextSplitter import click import git import gitdb @@ -220,3 +222,35 @@ def get_documents(input_pattern: Dict[str, Union[str, List[str]]]) -> List[str]: # Cleanup: Remove the temporary directory if it exists if os.path.exists(temp_dir): shutil.rmtree(temp_dir) + + +def chunk_document(documents: List, max_context_size, chunk_word_count) -> List[str]: + """ + Iterates over the documents and splits them into chunks based on the word count provided by the user. + Args: + documents (List[str]): List of documents retrieved from git (can also consist of a single document) + max_context_size (int): Defaults to 4096 + chunk_word_count (int): Maximum number of words to chunk a document. + Returns: + List[str]: List of chunked documents. + """ + token_size = int(chunk_word_count * 1.3) # 1 word ~ 1.3 token + content = [] + if token_size < int(max_context_size - 1024): + text_splitter = RecursiveCharacterTextSplitter( + separators=["\n\n", "\n"], + chunk_size=int(token_size * 4), # 1 token ~ 4 English character + chunk_overlap=100, + ) + + for docs in documents: + temp = text_splitter.create_documents([docs]) + content.extend([item.page_content for item in temp]) + + else: + logger.error( + "Error: Given word count exceeds the required chunk limit i.e. 
2400" + ) + sys.exit() + + return content diff --git a/requirements.txt b/requirements.txt index 21e8e87e2a..f1dcd6b5f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,3 +30,4 @@ transformers>=4.30.0,<=4.38.2 trl>=0.7.11,<0.8.0 wandb>=0.16.4,<0.17.0 yamllint +langchain-text-splitters diff --git a/tests/test_config.py b/tests/test_config.py index 26547aea65..19ac037343 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -32,6 +32,7 @@ def _assert_defaults(self, cfg): self.assertEqual(cfg.generate.taxonomy_base, "origin/main") self.assertEqual(cfg.generate.num_cpus, 10) self.assertEqual(cfg.generate.num_instructions, 100) + self.assertEqual(cfg.generate.chunk_word_count, 1000) self.assertEqual(cfg.generate.output_dir, "generated") self.assertEqual(cfg.generate.prompt_file, "prompt.txt") self.assertEqual(cfg.generate.seed_file, "seed_tasks.json") @@ -88,6 +89,7 @@ def test_full_config(self): seed_file: seed_tasks.json taxonomy_base: origin/main taxonomy_path: taxonomy + chunk_word_count: 1000 serve: gpu_layers: -1 host_port: 127.0.0.1:8000 diff --git a/tests/test_lab_generate.py b/tests/test_lab_generate.py index 8e7426b7ec..fa938756c6 100644 --- a/tests/test_lab_generate.py +++ b/tests/test_lab_generate.py @@ -178,6 +178,7 @@ def test_OpenAI_server_error(self, get_instructions_from_model): prompt_file_path="prompt.txt", rouge_threshold=0.9, console_output=True, + chunk_word_count=1000, ) self.assertIn( "There was a problem connecting to the OpenAI server", @@ -208,6 +209,7 @@ def test_no_error(self, get_instructions_from_model): prompt_file_path="prompt.txt", rouge_threshold=0.9, console_output=True, + chunk_word_count=1000, ) get_instructions_from_model.assert_called_once() expected_files = [ From 42da613f7988811544b9c60410e579173dad665c Mon Sep 17 00:00:00 2001 From: aajha Date: Wed, 3 Apr 2024 21:43:39 +0530 Subject: [PATCH 2/2] use config default context size Signed-off-by: aajha --- cli/config.py | 1 + cli/generator/generate_data.py | 4 
+--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cli/config.py b/cli/config.py index 9c25bff898..27587066f5 100644 --- a/cli/config.py +++ b/cli/config.py @@ -14,6 +14,7 @@ DEFAULT_TAXONOMY_PATH = "taxonomy" DEFAULT_TAXONOMY_BASE = "origin/main" DEFAULT_YAML_RULES = "yaml_rules.yaml" +MAX_CONTEXT_SIZE = 4096 # TODO: these constants should be removed, they should not leak out DEFAULT_NUM_CPUS = 10 DEFAULT_CHUNK_WORD_COUNT = 1000 diff --git a/cli/generator/generate_data.py b/cli/generator/generate_data.py index 8358817318..ba2e569565 100755 --- a/cli/generator/generate_data.py +++ b/cli/generator/generate_data.py @@ -364,13 +364,11 @@ def generate_data( def unescape(s): return bytes(s, "utf-8").decode("utf-8") - cfg = config.get_default_config() - placeholder = seed_instruction_data[0]["document"] if placeholder: documents = chunk_document( documents=placeholder, - max_context_size=cfg.serve.max_ctx_size, + max_context_size=config.MAX_CONTEXT_SIZE, chunk_word_count=chunk_word_count, )