Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions 3 cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
DEFAULT_TAXONOMY_PATH = "taxonomy"
DEFAULT_TAXONOMY_BASE = "origin/main"
DEFAULT_YAML_RULES = "yaml_rules.yaml"
MAX_CONTEXT_SIZE = 4096
# TODO: these constants should be removed, they should not leak out
DEFAULT_NUM_CPUS = 10
DEFAULT_CHUNK_WORD_COUNT = 1000
DEFAULT_NUM_INSTRUCTIONS = 100
DEFAULT_PROMPT_FILE = "prompt.txt"
DEFAULT_GENERATED_FILES_OUTPUT_DIR = "generated"
Expand Down Expand Up @@ -67,6 +69,7 @@ class _generate(BaseModel):

# optional fields
num_cpus: Optional[PositiveInt] = DEFAULT_NUM_CPUS
chunk_word_count: Optional[PositiveInt] = DEFAULT_CHUNK_WORD_COUNT
num_instructions: Optional[PositiveInt] = DEFAULT_NUM_INSTRUCTIONS
output_dir: Optional[StrictStr] = DEFAULT_GENERATED_FILES_OUTPUT_DIR
prompt_file: Optional[StrictStr] = DEFAULT_PROMPT_FILE
Expand Down
23 changes: 22 additions & 1 deletion 23 cli/generator/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import yaml

# Local
from ..utils import get_documents, get_taxonomy_diff
from .. import config
from ..utils import chunk_document, get_documents, get_taxonomy_diff
from . import utils

DEFAULT_PROMPT_TEMPLATE = """\
Expand Down Expand Up @@ -337,6 +338,7 @@ def generate_data(
rouge_threshold: Optional[float] = None,
console_output=True,
api_key: Optional[str] = None,
chunk_word_count=None,
):
seed_instruction_data = []
generate_start = time.time()
Expand All @@ -362,9 +364,21 @@ def generate_data(
def unescape(s):
return bytes(s, "utf-8").decode("utf-8")

placeholder = seed_instruction_data[0]["document"]
if placeholder:
documents = chunk_document(
documents=placeholder,
max_context_size=config.MAX_CONTEXT_SIZE,
chunk_word_count=chunk_word_count,
)

test_data = []
for seed_example in seed_instruction_data:
user = seed_example["instruction"]

if placeholder:
seed_example["document"] = documents

if len(seed_example["input"]) > 0:
user += "\n" + seed_example["input"]
try:
Expand Down Expand Up @@ -582,6 +596,13 @@ def read_taxonomy_file(logger, file_path, yaml_rules: Optional[str] = None):
if documents:
documents = get_documents(documents)
logger.info("Content from git repo fetched")

# cfg = config.get_default_config()
# documents = chunk_document(
# documents=documents,
# max_context_size=cfg.serve.max_ctx_size,
# chunk_word_count=chunk_word_count,
# )
for t in get_seed_examples(contents):
q = t["question"]
a = t["answer"]
Expand Down
9 changes: 9 additions & 0 deletions 9 cli/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ def serve(ctx, model_path, gpu_layers, num_threads, max_ctx_size):
default=config.DEFAULT_NUM_CPUS,
show_default=True,
)
@click.option(
"--chunk-word-count",
type=click.INT,
help="Number of words to chunk the document",
default=config.DEFAULT_CHUNK_WORD_COUNT,
show_default=True,
)
@click.option(
"--num-instructions",
type=click.INT,
Expand Down Expand Up @@ -395,6 +402,7 @@ def generate(
endpoint_url,
api_key,
yaml_rules,
chunk_word_count,
):
"""Generates synthetic data to enhance your example data"""
# pylint: disable=C0415
Expand Down Expand Up @@ -441,6 +449,7 @@ def generate(
rouge_threshold=rouge_threshold,
console_output=not quiet,
yaml_rules=yaml_rules,
chunk_word_count=chunk_word_count,
)
except GenerateException as exc:
click.secho(
Expand Down
34 changes: 34 additions & 0 deletions 34 cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import re
import shutil
import subprocess
import sys

# Third Party
from langchain_text_splitters import RecursiveCharacterTextSplitter
import click
import git
import gitdb
Expand Down Expand Up @@ -220,3 +222,35 @@ def get_documents(input_pattern: Dict[str, Union[str, List[str]]]) -> List[str]:
# Cleanup: Remove the temporary directory if it exists
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)


def chunk_document(documents: List, max_context_size, chunk_word_count) -> List[str]:
    """
    Iterate over the documents and split them into chunks based on the word count provided by the user.

    Args:
        documents (List[str]): List of documents retrieved from git (can also consist of a single document).
        max_context_size (int): Maximum model context size, in tokens (e.g. 4096).
        chunk_word_count (int): Maximum number of words per chunk.

    Returns:
        List[str]: List of chunked documents.
    """
    token_size = int(chunk_word_count * 1.3)  # 1 word ~ 1.3 tokens
    # Reserve 1024 tokens of the context window for the prompt/instructions.
    max_doc_tokens = max_context_size - 1024
    if token_size >= max_doc_tokens:
        # Compute the actual word limit instead of hard-coding it: the old
        # message claimed "2400", which was wrong even for a 4096-token context.
        max_words = int(max_doc_tokens / 1.3)
        logger.error(
            "Error: given word count (%d) exceeds the maximum of %d words "
            "allowed for a context size of %d tokens",
            chunk_word_count,
            max_words,
            max_context_size,
        )
        # Exit with a non-zero status; a bare sys.exit() reported success (0)
        # even though this is an error path.
        sys.exit(1)

    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n"],
        chunk_size=token_size * 4,  # 1 token ~ 4 English characters
        chunk_overlap=100,
    )

    content = []
    for doc in documents:
        chunks = text_splitter.create_documents([doc])
        content.extend(chunk.page_content for chunk in chunks)

    return content
1 change: 1 addition & 0 deletions 1 requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ transformers>=4.30.0,<=4.38.2
trl>=0.7.11,<0.8.0
wandb>=0.16.4,<0.17.0
yamllint
langchain-text-splitters
2 changes: 2 additions & 0 deletions 2 tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _assert_defaults(self, cfg):
self.assertEqual(cfg.generate.taxonomy_base, "origin/main")
self.assertEqual(cfg.generate.num_cpus, 10)
self.assertEqual(cfg.generate.num_instructions, 100)
self.assertEqual(cfg.generate.chunk_word_count, 1000)
self.assertEqual(cfg.generate.output_dir, "generated")
self.assertEqual(cfg.generate.prompt_file, "prompt.txt")
self.assertEqual(cfg.generate.seed_file, "seed_tasks.json")
Expand Down Expand Up @@ -88,6 +89,7 @@ def test_full_config(self):
seed_file: seed_tasks.json
taxonomy_base: origin/main
taxonomy_path: taxonomy
chunk_word_count: 1000
serve:
gpu_layers: -1
host_port: 127.0.0.1:8000
Expand Down
2 changes: 2 additions & 0 deletions 2 tests/test_lab_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_OpenAI_server_error(self, get_instructions_from_model):
prompt_file_path="prompt.txt",
rouge_threshold=0.9,
console_output=True,
chunk_word_count=1000,
)
self.assertIn(
"There was a problem connecting to the OpenAI server",
Expand Down Expand Up @@ -208,6 +209,7 @@ def test_no_error(self, get_instructions_from_model):
prompt_file_path="prompt.txt",
rouge_threshold=0.9,
console_output=True,
chunk_word_count=1000,
)
get_instructions_from_model.assert_called_once()
expected_files = [
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.