Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit cad035c

Browse filesBrowse files
yeesiancopybara-github
authored andcommitted
feat: Add enable_tracing to LangchainAgent.
PiperOrigin-RevId: 641955580
1 parent a78a35e commit cad035c
Copy full SHA for cad035c

4 files changed

+194-1Lines changed: 194 additions & 1 deletion

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎setup.py‎

Copy file name to clipboardExpand all lines: setup.py
+4-1Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@
140140

141141
reasoning_engine_extra_require = [
142142
"cloudpickle >= 2.2.1, < 4.0",
143+
"opentelemetry-sdk < 2",
144+
"opentelemetry-exporter-gcp-trace < 2",
143145
"pydantic >= 2.6.3, < 3",
144146
]
145147

@@ -149,9 +151,10 @@
149151
]
150152

151153
langchain_extra_require = [
152-
"langchain >= 0.1.16, < 0.2",
154+
"langchain >= 0.1.16, < 0.3",
153155
"langchain-core < 0.2",
154156
"langchain-google-vertexai < 2",
157+
"openinference-instrumentation-langchain >= 0.1.19, < 0.2",
155158
]
156159

157160
langchain_testing_extra_require = list(
Collapse file

‎tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py‎

Copy file name to clipboardExpand all lines: tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py
+78Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from vertexai.preview import reasoning_engines
2424
from vertexai.preview.generative_models import grounding
2525
from vertexai.generative_models import Tool
26+
from vertexai.reasoning_engines import _utils
2627
import pytest
2728

2829

@@ -89,6 +90,48 @@ def mock_chatvertexai():
8990
yield model_mock
9091

9192

93+
@pytest.fixture
94+
def cloud_trace_exporter_mock():
95+
with mock.patch.object(
96+
_utils,
97+
"_import_cloud_trace_exporter_or_warn",
98+
) as cloud_trace_exporter_mock:
99+
yield cloud_trace_exporter_mock
100+
101+
102+
@pytest.fixture
103+
def tracer_provider_mock():
104+
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
105+
yield tracer_provider_mock
106+
107+
108+
@pytest.fixture
109+
def simple_span_processor_mock():
110+
with mock.patch(
111+
"opentelemetry.sdk.trace.export.SimpleSpanProcessor"
112+
) as simple_span_processor_mock:
113+
yield simple_span_processor_mock
114+
115+
116+
@pytest.fixture
117+
def langchain_instrumentor_mock():
118+
with mock.patch.object(
119+
_utils,
120+
"_import_openinference_langchain_or_warn",
121+
) as langchain_instrumentor_mock:
122+
yield langchain_instrumentor_mock
123+
124+
125+
@pytest.fixture
126+
def langchain_instrumentor_none_mock():
127+
with mock.patch.object(
128+
_utils,
129+
"_import_openinference_langchain_or_warn",
130+
) as langchain_instrumentor_mock:
131+
langchain_instrumentor_mock.return_value = None
132+
yield langchain_instrumentor_mock
133+
134+
92135
@pytest.mark.usefixtures("google_auth_mock")
93136
class TestLangchainAgent:
94137
def setup_method(self):
@@ -175,6 +218,41 @@ def test_query(self, langchain_dump_mock):
175218
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
176219
)
177220

221+
@pytest.mark.usefixtures("caplog")
222+
def test_enable_tracing(
223+
self,
224+
caplog,
225+
cloud_trace_exporter_mock,
226+
tracer_provider_mock,
227+
simple_span_processor_mock,
228+
langchain_instrumentor_mock,
229+
):
230+
agent = reasoning_engines.LangchainAgent(
231+
model=_TEST_MODEL,
232+
prompt=self.prompt,
233+
output_parser=self.output_parser,
234+
enable_tracing=True,
235+
)
236+
assert agent._instrumentor is None
237+
agent.set_up()
238+
assert agent._instrumentor is not None
239+
assert (
240+
"enable_tracing=True but proceeding with tracing disabled"
241+
not in caplog.text
242+
)
243+
244+
@pytest.mark.usefixtures("caplog")
245+
def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
246+
agent = reasoning_engines.LangchainAgent(
247+
model=_TEST_MODEL,
248+
prompt=self.prompt,
249+
output_parser=self.output_parser,
250+
enable_tracing=True,
251+
)
252+
assert agent._instrumentor is None
253+
agent.set_up()
254+
assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
255+
178256

179257
class TestConvertToolsOrRaise:
180258
def test_convert_tools_or_raise(self, vertexai_init_mock):
Collapse file

‎vertexai/preview/reasoning_engines/templates/langchain.py‎

Copy file name to clipboardExpand all lines: vertexai/preview/reasoning_engines/templates/langchain.py
+45Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __init__(
236236
runnable_kwargs: Optional[Mapping[str, Any]] = None,
237237
model_builder: Optional[Callable] = None,
238238
runnable_builder: Optional[Callable] = None,
239+
enable_tracing: bool = False,
239240
):
240241
"""Initializes the LangchainAgent.
241242
@@ -349,6 +350,9 @@ def __init__(
349350
for customizing the orchestration logic of the Agent based on
350351
the model returned by `model_builder` and the rest of the input
351352
arguments.
353+
enable_tracing (bool):
354+
Optional. Whether to enable tracing in Cloud Trace. Defaults to
355+
False.
352356
353357
Raises:
354358
TypeError: If there is an invalid tool (e.g. function with an input
@@ -376,6 +380,8 @@ def __init__(
376380
self._model_builder = model_builder
377381
self._runnable = None
378382
self._runnable_builder = runnable_builder
383+
self._instrumentor = None
384+
self._enable_tracing = enable_tracing
379385

380386
def set_up(self):
381387
"""Sets up the agent for execution of queries at runtime.
@@ -387,6 +393,44 @@ def set_up(self):
387393
the ReasoningEngine service for deployment, as it initializes clients
388394
that can not be serialized.
389395
"""
396+
if self._enable_tracing:
397+
from vertexai.reasoning_engines import _utils
398+
399+
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
400+
openinference_langchain = _utils._import_openinference_langchain_or_warn()
401+
opentelemetry = _utils._import_opentelemetry_or_warn()
402+
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
403+
if all(
404+
(
405+
cloud_trace_exporter,
406+
openinference_langchain,
407+
opentelemetry,
408+
opentelemetry_sdk_trace,
409+
)
410+
):
411+
tracer_provider = opentelemetry.trace.get_tracer_provider()
412+
if tracer_provider and _utils._is_noop_tracer_provider(tracer_provider):
413+
# Set a trace provider if it has not been set.
414+
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
415+
project_id=self._project,
416+
)
417+
span_processor = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
418+
span_exporter=span_exporter,
419+
)
420+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
421+
active_span_processor=span_processor,
422+
)
423+
opentelemetry.trace.set_tracer_provider(tracer_provider)
424+
self._instrumentor = openinference_langchain.LangChainInstrumentor()
425+
self._instrumentor.instrument()
426+
else:
427+
from google.cloud.aiplatform import base
428+
429+
_LOGGER = base.Logger(__name__)
430+
_LOGGER.warning(
431+
"enable_tracing=True but proceeding with tracing disabled "
432+
"because not all packages for tracing have been installed"
433+
)
390434
model_builder = self._model_builder or _default_model_builder
391435
self._model = model_builder(
392436
model_name=self._model_name,
@@ -422,6 +466,7 @@ def clone(self) -> "LangchainAgent":
422466
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
423467
model_builder=self._model_builder,
424468
runnable_builder=self._runnable_builder,
469+
enable_tracing=self._enable_tracing,
425470
)
426471

427472
def query(
Collapse file

‎vertexai/reasoning_engines/_utils.py‎

Copy file name to clipboardExpand all lines: vertexai/reasoning_engines/_utils.py
+67Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import proto
2323

24+
from google.cloud.aiplatform import base
2425
from google.protobuf import struct_pb2
2526
from google.protobuf import json_format
2627

@@ -36,6 +37,8 @@
3637

3738
JsonDict = Dict[str, Any]
3839

40+
_LOGGER = base.Logger(__name__)
41+
3942

4043
def to_proto(
4144
obj: Union[JsonDict, proto.Message],
@@ -195,6 +198,14 @@ def generate_schema(
195198
return schema
196199

197200

201+
def _is_noop_tracer_provider(tracer_provider) -> bool:
202+
"""Returns True if the tracer_provider is Proxy or NoOp."""
203+
opentelemetry = _import_opentelemetry_or_warn()
204+
ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
205+
NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
206+
return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))
207+
208+
198209
def _import_cloud_storage_or_raise() -> types.ModuleType:
199210
"""Tries to import the Cloud Storage module."""
200211
try:
@@ -233,3 +244,59 @@ def _import_pydantic_or_raise() -> types.ModuleType:
233244
"'pip install google-cloud-aiplatform[reasoningengine]'."
234245
) from e
235246
return pydantic
247+
248+
249+
def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
250+
"""Tries to import the opentelemetry module."""
251+
try:
252+
import opentelemetry # noqa:F401
253+
254+
return opentelemetry
255+
except ImportError:
256+
_LOGGER.warning(
257+
"opentelemetry-sdk is not installed. Please call "
258+
"'pip install google-cloud-aiplatform[reasoningengine]'."
259+
)
260+
return None
261+
262+
263+
def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
264+
"""Tries to import the opentelemetry.sdk.trace module."""
265+
try:
266+
import opentelemetry.sdk.trace # noqa:F401
267+
268+
return opentelemetry.sdk.trace
269+
except ImportError:
270+
_LOGGER.warning(
271+
"opentelemetry-sdk is not installed. Please call "
272+
"'pip install google-cloud-aiplatform[reasoningengine]'."
273+
)
274+
return None
275+
276+
277+
def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
278+
"""Tries to import the opentelemetry.exporter.cloud_trace module."""
279+
try:
280+
import opentelemetry.exporter.cloud_trace # noqa:F401
281+
282+
return opentelemetry.exporter.cloud_trace
283+
except ImportError:
284+
_LOGGER.warning(
285+
"opentelemetry-exporter-gcp-trace is not installed. Please "
286+
"call 'pip install google-cloud-aiplatform[langchain]'."
287+
)
288+
return None
289+
290+
291+
def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
292+
"""Tries to import the openinference.instrumentation.langchain module."""
293+
try:
294+
import openinference.instrumentation.langchain # noqa:F401
295+
296+
return openinference.instrumentation.langchain
297+
except ImportError:
298+
_LOGGER.warning(
299+
"openinference-instrumentation-langchain is not installed. Please "
300+
"call 'pip install google-cloud-aiplatform[langchain]'."
301+
)
302+
return None

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.