From 254f19444b909bd08c30444d7eae90d6b36658eb Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 2 Jun 2023 20:13:05 -0700 Subject: [PATCH] support for falcon --- pgml-extension/requirements.txt | 1 + pgml-extension/src/bindings/transformers.py | 4 ++++ pgml-extension/src/bindings/transformers.rs | 17 ++++++++++++----- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pgml-extension/requirements.txt b/pgml-extension/requirements.txt index 405dc0a70..95bf71d35 100644 --- a/pgml-extension/requirements.txt +++ b/pgml-extension/requirements.txt @@ -19,3 +19,4 @@ tqdm==4.65.0 transformers==4.29.2 xgboost==1.7.5 langchain==0.0.180 +einops==0.6.1 diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 8334532d1..79e61a5d6 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -97,6 +97,10 @@ def transform(task, args, inputs): ensure_device(task) convert_dtype(task) + model = task.get("model", None) + if model and "tokenizer" not in task: + task["tokenizer"] = AutoTokenizer.from_pretrained(model) + if key not in __cache_transform_pipeline_by_task: __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task) pipe = __cache_transform_pipeline_by_task[key] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 8f296d812..65c24bcd6 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -35,17 +35,24 @@ pub fn transform( let results = Python::with_gil(|py| -> String { let transform: Py = PY_MODULE.getattr(py, "transform").unwrap().into(); - transform + let result = transform .call1( py, PyTuple::new( py, &[task.into_py(py), args.into_py(py), inputs.into_py(py)], ), - ) - .unwrap() - .extract(py) - .unwrap() + ); + + let result = match result { + Err(e) => { + let traceback = e.traceback(py).unwrap().format().unwrap(); + error!("{traceback} {e}") + } + Ok(o) => o.extract(py).unwrap(), + }; + + result }); serde_json::from_str(&results).unwrap() }