Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 07ee41c

Browse files
authored
Adding a way to clear GPU memory (#722)
1 parent 22d16cf commit 07ee41c
Copy full SHA for 07ee41c

File tree

Expand file tree / Collapse file tree

4 files changed

+65
-0
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+65
-0
lines changed
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- src/api.rs:599
-- pgml::api::clear_gpu_cache
--
-- Clears the GPU cache, optionally only when usage is at or above the given
-- fraction (0.0 -> 1.0). Returns true when the cache was cleared.
--
-- Deliberately NOT declared STRICT: the argument defaults to NULL, and a
-- STRICT function returns NULL without being executed whenever an argument is
-- NULL, which would turn the default call into a no-op.
--
-- Declared VOLATILE rather than IMMUTABLE: the function has a side effect and
-- its result depends on the current GPU state, so the planner must not
-- constant-fold or skip calls to it.
CREATE FUNCTION pgml."clear_gpu_cache"(
	"memory_usage" REAL DEFAULT NULL /* Option<f32> */
) RETURNS bool /* bool */
VOLATILE PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'clear_gpu_cache_wrapper';

‎pgml-extension/src/api.rs

Copy file name to clipboardExpand all lines: pgml-extension/src/api.rs
+23Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,29 @@ pub fn embed_batch(
580580
crate::bindings::transformers::embed(transformer, inputs, &kwargs.0)
581581
}
582582

583+
584+
/// Clears the GPU cache.
585+
///
586+
/// # Arguments
587+
///
588+
/// * `memory_usage` - Optional parameter indicating the memory usage percentage (0.0 -> 1.0)
589+
///
590+
/// # Returns
591+
///
592+
/// Returns `true` if the GPU cache was successfully cleared, `false` otherwise.
593+
/// # Example
594+
///
595+
/// ```sql
596+
/// SELECT pgml.clear_gpu_cache(memory_usage => 0.5);
597+
/// ```
598+
#[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")]
599+
pub fn clear_gpu_cache(
600+
memory_usage: default!(Option<f32>, "NULL")
601+
) -> bool {
602+
let memory_usage: Option<f32> = memory_usage.map(|memory_usage| memory_usage.try_into().unwrap());
603+
crate::bindings::transformers::clear_gpu_cache(memory_usage)
604+
}
605+
583606
#[pg_extern(immutable, parallel_safe)]
584607
pub fn chunk(
585608
splitter: &str,

‎pgml-extension/src/bindings/transformers.py

Copy file name to clipboardExpand all lines: pgml-extension/src/bindings/transformers.py
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,17 @@ def embed(transformer, inputs, kwargs):
131131

132132
return model.encode(inputs, **kwargs)
133133

134+
def clear_gpu_cache(memory_usage=None):
    """Empty the PyTorch CUDA cache, optionally only above a usage threshold.

    Args:
        memory_usage: Optional threshold as a fraction (0.0 -> 1.0). When
            ``None`` the cache is always emptied. Otherwise it is emptied only
            if ``torch.cuda.memory_usage()`` is at or above the threshold
            expressed as a whole percentage.

    Returns:
        ``True`` if ``torch.cuda.empty_cache()`` was called, ``False`` otherwise.

    Raises:
        PgMLException: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise PgMLException("No GPU available")

    # NOTE(review): per the PyTorch docs, torch.cuda.memory_usage() reports the
    # percentage of time device memory was being read or written, not the
    # fraction of memory allocated. Confirm this is the intended metric.
    mem_used = torch.cuda.memory_usage()

    # round() instead of int(): float error makes 0.29 * 100 slightly less than
    # 29, which int() would truncate to 28 and lower the threshold by a point.
    if memory_usage is None or mem_used >= round(memory_usage * 100.0):
        torch.cuda.empty_cache()
        return True
    return False
144+
134145

135146
def load_dataset(name, subset, limit: None, kwargs: "{}"):
136147
kwargs = orjson.loads(kwargs)

‎pgml-extension/src/bindings/transformers.rs

Copy file name to clipboardExpand all lines: pgml-extension/src/bindings/transformers.rs
+23Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,26 @@ pub fn load_dataset(
311311

312312
num_rows
313313
}
314+
315+
pub fn clear_gpu_cache(
316+
memory_usage: Option<f32>
317+
) -> bool {
318+
319+
Python::with_gil(|py| -> bool {
320+
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap().into();
321+
clear_gpu_cache
322+
.call1(
323+
py,
324+
PyTuple::new(
325+
py,
326+
&[
327+
memory_usage.into_py(py),
328+
],
329+
),
330+
)
331+
.unwrap()
332+
.extract(py)
333+
.unwrap()
334+
})
335+
}
336+

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.