Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ca47835

Browse filesBrowse files
tswastgoogle-labs-jules[bot]gemini-code-assist[bot]
authored
feat: add support for hparam_range and hparam_candidates to bigframes.bigquery.create_model (#16640)
Fixes internal issue b/501171054 🦕 --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: tswast <247555+tswast@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent fe5245b commit ca47835
Copy full SHA for ca47835

5 files changed

+88-16Lines changed: 88 additions & 16 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/bigframes/bigframes/bigquery/__init__.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/__init__.py
+9-1Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@
8787
to_json,
8888
to_json_string,
8989
)
90-
from bigframes.bigquery._operations.mathematical import rand
90+
from bigframes.bigquery._operations.mathematical import (
91+
hparam_candidates,
92+
hparam_range,
93+
rand,
94+
)
9195
from bigframes.bigquery._operations.search import create_vector_index, vector_search
9296
from bigframes.bigquery._operations.sql import sql_scalar
9397
from bigframes.bigquery._operations.struct import struct
@@ -130,6 +134,8 @@
130134
to_json,
131135
to_json_string,
132136
# mathematical ops
137+
hparam_candidates,
138+
hparam_range,
133139
rand,
134140
# search ops
135141
create_vector_index,
@@ -187,6 +193,8 @@
187193
"to_json",
188194
"to_json_string",
189195
# mathematical ops
196+
"hparam_candidates",
197+
"hparam_range",
190198
"rand",
191199
# search ops
192200
"create_vector_index",
Collapse file

‎packages/bigframes/bigframes/bigquery/_operations/mathematical.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/_operations/mathematical.py
+70Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Sequence
18+
1719
import bigframes.core.col
1820
import bigframes.core.expression
1921
from bigframes import dtypes
@@ -51,3 +53,71 @@ def rand() -> bigframes.core.col.Expression:
5153
is_deterministic=False,
5254
)
5355
return bigframes.core.col.Expression(bigframes.core.expression.OpExpression(op, ()))
56+
57+
58+
def hparam_range(min: float, max: float) -> bigframes.core.col.Expression:
59+
"""
60+
Defines the minimum and maximum bounds of the search space of continuous
61+
values for a hyperparameter.
62+
63+
**Examples:**
64+
65+
>>> import bigframes.pandas as bpd
66+
>>> import bigframes.bigquery as bbq
67+
>>> # Specify a range of values for a hyperparameter.
68+
>>> learn_rate = bbq.hparam_range(0.0001, 1.0)
69+
70+
Args:
71+
min (float or int):
72+
The minimum bound of the search space.
73+
max (float or int):
74+
The maximum bound of the search space.
75+
76+
Returns:
77+
bigframes.pandas.api.typing.Expression:
78+
An expression that can be used in model options.
79+
"""
80+
min_expr = bigframes.core.expression.const(min)
81+
max_expr = bigframes.core.expression.const(max)
82+
83+
op = ops.SqlScalarOp(
84+
_output_type=dtypes.FLOAT_DTYPE,
85+
sql_template="HPARAM_RANGE({0}, {1})",
86+
is_deterministic=True,
87+
)
88+
return bigframes.core.col.Expression(
89+
bigframes.core.expression.OpExpression(op, (min_expr, max_expr))
90+
)
91+
92+
93+
def hparam_candidates(
94+
candidates: Sequence[float | str],
95+
) -> bigframes.core.col.Expression:
96+
"""
97+
Specifies the set of discrete values for the hyperparameter.
98+
99+
**Examples:**
100+
101+
>>> import bigframes.pandas as bpd
102+
>>> import bigframes.bigquery as bbq
103+
>>> # Specify a set of values for a hyperparameter.
104+
>>> optimizer = bbq.hparam_candidates(['ADAGRAD', 'SGD', 'FTRL'])
105+
106+
Args:
107+
candidates (Sequence[float | str]):
108+
The set of discrete values for the hyperparameter.
109+
110+
Returns:
111+
bigframes.pandas.api.typing.Expression:
112+
An expression that can be used in model options.
113+
"""
114+
candidates_expr = bigframes.core.expression.const(tuple(candidates))
115+
116+
op = ops.SqlScalarOp(
117+
_output_type=dtypes.STRING_DTYPE,
118+
sql_template="HPARAM_CANDIDATES({0})",
119+
is_deterministic=True,
120+
)
121+
return bigframes.core.col.Expression(
122+
bigframes.core.expression.OpExpression(op, (candidates_expr,))
123+
)
Collapse file

‎packages/bigframes/tests/unit/core/sql/snapshots/test_ml/test_create_model_expression_option/create_model_expression_option.sql‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/unit/core/sql/snapshots/test_ml/test_create_model_expression_option/create_model_expression_option.sql
-3Lines changed: 0 additions & 3 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_model`
2+
OPTIONS(model_type = 'LINEAR_REG', learn_rate = HPARAM_RANGE(0.0001, 1.0), optimizer = HPARAM_CANDIDATES(['ADAGRAD', 'SGD']))
3+
AS SELECT * FROM t
Collapse file

‎packages/bigframes/tests/unit/core/sql/test_ml.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/unit/core/sql/test_ml.py
+6-12Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616

17+
import bigframes.bigquery as bbq
1718
import bigframes.core.col as col
1819
import bigframes.core.expression as ex
1920
import bigframes.core.sql.ml
@@ -101,24 +102,17 @@ def test_create_model_list_option(snapshot):
101102
snapshot.assert_match(sql, "create_model_list_option.sql")
102103

103104

104-
def test_create_model_expression_option(snapshot):
105-
# An expression that calls a function on a literal value
106-
# e.g. 0.1 * 10
107-
literal_expr = ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE)
108-
multiplier_expr = ex.ScalarConstantExpression(10, dtypes.INT_DTYPE)
109-
math_expr = col.Expression(
110-
ex.OpExpression(op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr))
111-
)
112-
105+
def test_create_model_hparam_tuning(snapshot):
113106
sql = bigframes.core.sql.ml.create_model_ddl(
114107
model_name="my_model",
115108
options={
116-
"l2_reg": math_expr,
117-
"booster_type": "gbtree",
109+
"model_type": "LINEAR_REG",
110+
"learn_rate": bbq.hparam_range(0.0001, 1.0),
111+
"optimizer": bbq.hparam_candidates(["ADAGRAD", "SGD"]),
118112
},
119113
training_data="SELECT * FROM t",
120114
)
121-
snapshot.assert_match(sql, "create_model_expression_option.sql")
115+
snapshot.assert_match(sql, "create_model_hparam_tuning.sql")
122116

123117

124118
def test_evaluate_model_basic(snapshot):

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.