Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ed2d072

Browse filesBrowse files
authored
support for falcon (#676)
1 parent 32d18b2 commit ed2d072
Copy full SHA for ed2d072

File tree

3 files changed

+17
-5
lines changed
Filter options

3 files changed

+17
-5
lines changed

‎pgml-extension/requirements.txt

Copy file name to clipboardExpand all lines: pgml-extension/requirements.txt
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ tqdm==4.65.0
1919
transformers==4.29.2
2020
xgboost==1.7.5
2121
langchain==0.0.180
22+
einops==0.6.1

‎pgml-extension/src/bindings/transformers.py

Copy file name to clipboardExpand all lines: pgml-extension/src/bindings/transformers.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def transform(task, args, inputs):
9797
ensure_device(task)
9898
convert_dtype(task)
9999

100+
model = task.get("model", None)
101+
if model and "tokenizer" not in task:
102+
task["tokenizer"] = AutoTokenizer.from_pretrained(model)
103+
100104
if key not in __cache_transform_pipeline_by_task:
101105
__cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
102106
pipe = __cache_transform_pipeline_by_task[key]

‎pgml-extension/src/bindings/transformers.rs

Copy file name to clipboardExpand all lines: pgml-extension/src/bindings/transformers.rs
+12-5Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,24 @@ pub fn transform(
3535
let results = Python::with_gil(|py| -> String {
3636
let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform").unwrap().into();
3737

38-
transform
38+
let result = transform
3939
.call1(
4040
py,
4141
PyTuple::new(
4242
py,
4343
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
4444
),
45-
)
46-
.unwrap()
47-
.extract(py)
48-
.unwrap()
45+
);
46+
47+
let result = match result {
48+
Err(e) => {
49+
let traceback = e.traceback(py).unwrap().format().unwrap();
50+
error!("{traceback} {e}")
51+
}
52+
Ok(o) => o.extract(py).unwrap(),
53+
};
54+
55+
result
4956
});
5057
serde_json::from_str(&results).unwrap()
5158
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.