
ExportArchive can not deal with StaticHashTable #93452

Open
@PowerToThePeople111

Description


Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

2.17 (for my env) and 2.18 (colab minimal example)

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

I try to apply thresholds to model outputs by using a StaticHashTable. Since I also use multiple custom layers and want to package them for TF Serving, I decided to go with keras.export.ExportArchive() as described in the docs. Because ExportArchive uses the saved_model format, it lets me completely package all custom code with the model, plus it gives me a way to pre-process inputs and post-process outputs of the model (applying the thresholds).

This is code that minimally extends the example from the docs referenced above (so it can actually run):

import tensorflow as tf
import keras

# Lookup table mapping output names to per-class thresholds.
threshold_keys = tf.constant(["a", "b"], dtype=tf.string, name="threshold_keys")
threshold_values = tf.constant([0.5, 0.7], dtype=tf.float32, name="threshold_values")
initializer = tf.lookup.KeyValueTensorInitializer(threshold_keys, threshold_values)
lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)

# Minimal model, built by calling it once.
model = keras.layers.Dense(1)
model(tf.constant([[0.5]]))

export_archive = keras.export.ExportArchive()
model_fn = export_archive.track_and_add_endpoint(
    "model_fn",
    model,
    input_signature=[tf.TensorSpec(shape=(None, 1), dtype=tf.float32)],
)

# Explicitly track the table so the archive should pick up its resources.
export_archive.track(lookup_table)

@tf.function()
def serving_fn(x):
    x = lookup_table.lookup(x)
    return model_fn(x)

x = tf.constant([["a"]])
serving_fn(x)  # works: predictions succeed
export_archive.add_endpoint(name="serve", fn=serving_fn)

export_archive.write_out("larifari")  # <-- the exception is raised here

You can also find this code at the bottom of the shared Colab notebook.
The model I built can successfully make predictions, but when I try to save it, I get an error that some tensor cannot be properly tracked:

AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
Function name = b'__inference_signature_wrapper_294'
Captured Tensor = <ResourceHandle(name="101", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[  ]")>
Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x79e857dbcf10>
Internal Tensor = Tensor("284:0", shape=(), dtype=resource)

The untracked tensor seems to be referenced by the StaticHashTable and is probably the resource_handle of that class. What I also tried:

  • constructing the StaticHashTable in the init method of the wrapper model and setting it as a class attribute
  • exposing the resource_handle of the StaticHashTable, tracking it explicitly, and setting it as a class attribute
  • using self._track_trackable to force tracking
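For reference, the first attempt above looked roughly like this. This is a minimal sketch, assuming a bare tf.Module and tf.saved_model.save instead of the Keras wrapper model and ExportArchive from the repro; the name ThresholdModule and the export path are made up for illustration:

```python
import tensorflow as tf

class ThresholdModule(tf.Module):
    def __init__(self):
        super().__init__()
        keys = tf.constant(["a", "b"], dtype=tf.string)
        values = tf.constant([0.5, 0.7], dtype=tf.float32)
        initializer = tf.lookup.KeyValueTensorInitializer(keys, values)
        # Assigning the table to an attribute of the module is the
        # standard way to let SavedModel track its resource handle.
        self.table = tf.lookup.StaticHashTable(initializer, default_value=0.0)

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.string)])
    def serve(self, x):
        # Look up the per-key threshold; unknown keys fall back to 0.0.
        return self.table.lookup(x)

module = ThresholdModule()
tf.saved_model.save(module, "/tmp/threshold_module")
```

With the plain tf.Module route the attribute assignment is enough for tracking; the question is how to get the equivalent behavior through ExportArchive.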

Can you tell me how to export this model for TF Serving usage, with custom code and a function for pre- and post-processing that uses a StaticHashTable?

I have been trying to solve this for two days and I am stuck.

Standalone code to reproduce the issue

https://colab.research.google.com/drive/1EpYpn6EhvY4PyKX65xrPd_QZVo3Dw2uO#scrollTo=rxp3GhfJRULi

Relevant log output

AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
Function name = b'__inference_signature_wrapper_294'
Captured Tensor = <ResourceHandle(name="101", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[  ]")>
Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x79e857dbcf10>
Internal Tensor = Tensor("284:0", shape=(), dtype=resource)
