Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Latest commit

 

History

History
History
115 lines (94 loc) · 3.99 KB

File metadata and controls

115 lines (94 loc) · 3.99 KB
Copy raw file
Download raw file
Open symbols panel
Edit and raw actions
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import h5py
import numpy as np
import argparse
import uuid
import model_pb2
def quantize_arr(arr):
    """Quantize an array to 8-bit unsigned ints by linear min/max rescaling.

    Values are mapped affinely so that ``min(arr)`` becomes 0 and ``max(arr)``
    becomes 255, then rounded with ``np.round`` (round half to even). When
    the range is not positive (all values equal, or the range is NaN) there
    is nothing to rescale over and the result is all zeros.

    Args:
        arr: Array-like of real numbers. Must be non-empty (``np.min`` raises
            ``ValueError`` on an empty array).

    Returns:
        A tuple ``(quantized, min_val, max_val)``: ``quantized`` is a
        ``np.uint8`` array with the same shape as ``arr``; ``min_val`` and
        ``max_val`` are the original range as ``np.float32`` scalars, so a
        value can be approximately recovered as
        ``min_val + q * (max_val - min_val) / 255``.
    """
    arr = np.asarray(arr)
    if arr.dtype.kind in 'biu':
        # Bool/int inputs would keep their narrow dtype through
        # ``255 * (arr - min_val)`` and silently wrap around (bool cannot
        # even be subtracted), so do the arithmetic in float64. Float inputs
        # are left as-is so their results are unchanged.
        arr = arr.astype(np.float64)
    min_val, max_val = np.min(arr), np.max(arr)
    if max_val - min_val > 0:
        quantized = np.round(255 * (arr - min_val) / (max_val - min_val))
    else:
        quantized = np.zeros(arr.shape)
    quantized = quantized.astype(np.uint8)
    min_val = np.float32(min_val)
    max_val = np.float32(max_val)
    return quantized, min_val, max_val
class Encoder:
    """Encoder class.

    Takes as input a Keras model saved in hdf5 format that includes the model architecture with the weights.
    This is the resulting file from running the command:
    ```
    model.save('my_model.h5')
    ```
    See https://keras.io/getting-started/faq/#savingloading-whole-models-architecture-weights-optimizer-state

    Typical use is ``serialize()`` followed by ``save()``, which writes
    ``<name>.bin`` next to the input file.
    """

    def __init__(self, hdf5_model_filepath, name, quantize):
        """Create an encoder and its empty protobuf model.

        Args:
            hdf5_model_filepath: Path to the Keras ``.h5`` file to read.
            name: Model name; stored in the protobuf and used as the base
                name of the output file written by ``save``.
            quantize: If true, weights are stored as 8-bit unsigned ints
                (via ``quantize_arr``); otherwise as raw float32 bytes.

        Raises:
            ValueError: If ``hdf5_model_filepath`` is empty or None.
        """
        if not hdf5_model_filepath:
            raise ValueError('hdf5_model_filepath must be provided.')
        self.hdf5_model_filepath = hdf5_model_filepath
        self.name = name
        self.quantize = quantize
        self.create_model()

    def create_model(self):
        """Initialize ``self.model`` from the protobuf definition with a random UUID4 id."""
        self.model = model_pb2.Model()
        self.model.id = str(uuid.uuid4())
        self.model.name = self.name

    @staticmethod
    def _as_str(value):
        """Return ``value`` as ``str``, decoding ``bytes`` as UTF-8.

        Depending on how the file was written and on the h5py version,
        string attributes come back either as ``str`` or as
        ``bytes``/``numpy.bytes_``.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return str(value)

    def serialize(self):
        """Copy metadata and all weights from the hdf5 file into ``self.model``.

        Strategy for extracting the weights is adapted from the
        load_weights_from_hdf5_group method of the Container class:
        see https://github.com/fchollet/keras/blob/master/keras/engine/topology.py#L2505-L2585

        Each call appends to ``self.model.model_weights``; calling it twice on
        the same instance adds the weights twice.

        Raises:
            KeyError: If an expected attribute or group (e.g.
                ``model_weights``) is missing from the file.
        """
        # ``with`` guarantees the file is closed even if reading fails midway.
        with h5py.File(self.hdf5_model_filepath, mode='r') as hdf5_file:
            self.model.keras_version = self._as_str(hdf5_file.attrs['keras_version'])
            self.model.backend = self._as_str(hdf5_file.attrs['backend'])
            self.model.model_config = self._as_str(hdf5_file.attrs['model_config'])
            weights_group = hdf5_file['model_weights']
            for layer_name in weights_group.attrs['layer_names']:
                layer_name = self._as_str(layer_name)
                layer_group = weights_group[layer_name]
                for weight_name in layer_group.attrs['weight_names']:
                    weight_name = self._as_str(weight_name)
                    # ``Dataset.value`` was removed in h5py 3.0; ``[()]``
                    # reads the whole dataset into memory.
                    weight_value = layer_group[weight_name][()]
                    self._add_weight(layer_name, weight_name, weight_value)

    def _add_weight(self, layer_name, weight_name, weight_value):
        """Append one weight tensor to ``self.model.model_weights``.

        Uses ``self.quantize`` (the constructor argument) to choose between
        uint8-quantized and float32 storage.
        """
        w = self.model.model_weights.add()
        w.layer_name = layer_name
        w.weight_name = weight_name
        w.shape.extend(list(weight_value.shape))
        if self.quantize:
            w.type = 'uint8'
            quantized, min_val, max_val = quantize_arr(weight_value)
            # quantize_arr already returns a uint8 array.
            w.data = quantized.tobytes()
            w.quantize_min = float(min_val)
            w.quantize_max = float(max_val)
        else:
            w.type = 'float32'
            w.data = weight_value.astype(np.float32).tobytes()

    def save(self):
        """Save ``self.model`` as a binary protobuf message.

        The output is ``<name>.bin`` in the same directory as the input hdf5
        file; the absolute path is printed.

        Returns:
            The path of the written file.
        """
        pb_model_filepath = os.path.join(os.path.dirname(self.hdf5_model_filepath),
                                         '{}.bin'.format(self.name))
        with open(pb_model_filepath, 'wb') as f:
            f.write(self.model.SerializeToString())
        print('Saved to binary file {}'.format(os.path.abspath(pb_model_filepath)))
        return pb_model_filepath
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, then encode the hdf5
    # model and write the protobuf next to it.
    parser = argparse.ArgumentParser()
    parser.add_argument('hdf5_model_filepath')
    parser.add_argument(
        '-n', '--name',
        type=str,
        required=False,
        help='model name (defaults to filename without extension if not provided)',
    )
    parser.add_argument(
        '-q', '--quantize',
        action='store_true',
        required=False,
        help='quantize weights to 8-bit unsigned int',
    )
    args = parser.parse_args()

    hdf5_model_filepath = args.hdf5_model_filepath
    # Without an explicit --name, use the input file's base name minus its extension.
    name = (args.name if args.name is not None
            else os.path.splitext(os.path.basename(hdf5_model_filepath))[0])
    # These stay module-level names, exactly as before.
    quantize = args.quantize

    encoder = Encoder(hdf5_model_filepath, name, quantize)
    encoder.serialize()
    encoder.save()
Morty Proxy This is a proxified and sanitized view of the page, visit original site.