Description
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
2.17 (for my env) and 2.18 (colab minimal example)
Custom code
Yes
OS platform and distribution
No response
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
I am trying to apply thresholds to model outputs using a StaticHashTable. Because I also use several custom layers and want to package everything for TF Serving, I decided to use keras.export.ExportArchive() as described in the docs.
Since ExportArchive uses the SavedModel format, it lets me package all custom code with the model and also add functions that pre-process inputs and post-process outputs of the model (applying the thresholds).
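To make the intended post-processing step concrete, here is a minimal sketch that runs outside of any export: per-class thresholds are read from a StaticHashTable and compared against model scores. The (batch, num_classes) score layout, the class-name keys and the apply_thresholds function are hypothetical and only illustrate the idea; they are not taken from the real model.

```python
import tensorflow as tf

# Hypothetical per-class thresholds, mirroring threshold_keys/threshold_values.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(["a", "b"]), tf.constant([0.5, 0.7], dtype=tf.float32)
    ),
    default_value=0.0,
)


@tf.function
def apply_thresholds(scores, class_names):
    # scores: (batch, num_classes) float32 outputs (assumed layout).
    # class_names: (num_classes,) string keys into the table.
    thresholds = table.lookup(class_names)  # shape (num_classes,)
    # Broadcasts the per-class thresholds over the batch dimension.
    return scores >= thresholds


decisions = apply_thresholds(
    tf.constant([[0.6, 0.6], [0.4, 0.8]]), tf.constant(["a", "b"])
)
```

Each score is marked True only if it reaches the threshold of its class, so decisions is a boolean tensor of the same shape as scores.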
This is a minimal extension of the example from the docs referenced above (extended so that it actually runs):
```python
import tensorflow as tf
import keras

# Per-key thresholds stored in a lookup table.
threshold_keys = tf.constant(["a", "b"], dtype=tf.string, name="threshold_keys")
threshold_values = tf.constant([0.5, 0.7], dtype=tf.float32, name="threshold_values")
initializer = tf.lookup.KeyValueTensorInitializer(threshold_keys, threshold_values)
lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)

# A trivial model, called once so its weights are created before export.
model = keras.layers.Dense(1)
model(tf.constant([[0.5]]))

export_archive = keras.export.ExportArchive()
model_fn = export_archive.track_and_add_endpoint(
    "model_fn",
    model,
    input_signature=[tf.TensorSpec(shape=(None, 1), dtype=tf.float32)],
)
# Track the table explicitly so that it is exported with the archive.
export_archive.track(lookup_table)


@tf.function()
def serving_fn(x):
    # Map string keys to thresholds, then run the model on them.
    x = lookup_table.lookup(x)
    return model_fn(x)


# Call serving_fn once so it is traced before being added as an endpoint.
x = tf.constant([["a"]])
serving_fn(x)
export_archive.add_endpoint(name="serve", fn=serving_fn)
export_archive.write_out("larifari")  # <--- the exception is raised here
```
You can also find this code at the bottom of the shared Colab notebook.
The model makes predictions successfully, but when I try to save it, I get an error saying that a tensor cannot be tracked:
```
AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
Function name = b'__inference_signature_wrapper_294'
Captured Tensor = <ResourceHandle(name="101", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[ ]")>
Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x79e857dbcf10>
Internal Tensor = Tensor("284:0", shape=(), dtype=resource)
```
It seems that the untracked tensor is referenced by the StaticHashTable and is probably that class's resource_handle. Other things I tried:
- Constructing the StaticHashTable in the __init__ method of a wrapper model and assigning it to an attribute
- Additionally exposing the resource_handle of the StaticHashTable, tracking it explicitly and assigning it to an attribute
- Using self._track_trackable to force tracking
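For comparison, here is a minimal sketch of the pattern the error message describes ("assigning them to an attribute of a tracked object"). It assumes plain tf.saved_model.save instead of keras.export.ExportArchive, and a hand-written weight and bias in place of the Keras Dense layer. ThresholdModule and serve are hypothetical names. Because it bypasses ExportArchive entirely, it does not show whether ExportArchive.track is expected to handle the table.

```python
import tempfile

import tensorflow as tf


class ThresholdModule(tf.Module):
    """Hypothetical wrapper that owns the table so SavedModel tracks it."""

    def __init__(self):
        super().__init__()
        keys = tf.constant(["a", "b"], dtype=tf.string)
        values = tf.constant([0.5, 0.7], dtype=tf.float32)
        # Assigning the table to an attribute of a tf.Module makes it a
        # tracked child, so its resource is saved and re-created on load.
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values), default_value=0.0
        )
        # Stand-in for the Keras Dense(1) layer: one weight and one bias.
        self.w = tf.Variable([[2.0]])
        self.b = tf.Variable([0.0])

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 1), dtype=tf.string)])
    def serve(self, x):
        thresholds = self.table.lookup(x)  # shape (None, 1), float32
        return {"scores": tf.matmul(thresholds, self.w) + self.b}


module = ThresholdModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir, signatures={"serve": module.serve})

# Reload and call the exported signature; unknown keys use default_value.
reloaded = tf.saved_model.load(export_dir)
result = reloaded.signatures["serve"](x=tf.constant([["a"], ["z"]]))
```

Here "a" maps to 0.5 (times the weight 2.0) and the unknown key "z" falls back to the default value 0.0.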
How can I export this model for TF Serving with custom code and a pre-/post-processing function that uses a StaticHashTable?
I have been trying to solve this for two days and am stuck.
Standalone code to reproduce the issue
https://colab.research.google.com/drive/1EpYpn6EhvY4PyKX65xrPd_QZVo3Dw2uO#scrollTo=rxp3GhfJRULi