Skip to content

Commit be0ebe5

Browse files
committed
Fix filename tensor name in SaverDef
1 parent 2744e6c commit be0ebe5

File tree

4 files changed

+70
-56
lines changed

4 files changed

+70
-56
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -461,11 +461,12 @@ synchronized SaverDef saverDef() {
461461
// regenerate SaverDef without mutating. The names mirror
462462
// the python implementation for compatibility.
463463
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
464-
saverDef = SaverDef.newBuilder()
465-
.setFilenameTensorName("save/filename")
466-
.setSaveTensorName("save/control_dependency")
467-
.setRestoreOpName("save/restore_all")
468-
.build();
464+
saverDef =
465+
SaverDef.newBuilder()
466+
.setFilenameTensorName("save/filename:0")
467+
.setSaveTensorName("save/control_dependency")
468+
.setRestoreOpName("save/restore_all")
469+
.build();
469470
}
470471
}
471472
return saverDef;
@@ -812,36 +813,35 @@ private static SaverDef addVariableSaver(Graph graph) {
812813
}
813814
}
814815

815-
// FIXME Need an easier way to initialize an NdArray from a list
816-
String[] tmp = new String[varNames.size()];
817-
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
818-
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
819-
820-
Placeholder<TString> saveFilename = tf.withName("filename").placeholder(TString.class);
821-
Save saveVariables = tf.train.save(
822-
saveFilename,
823-
varNamesTensor,
824-
varSlices,
825-
varOutputs
826-
);
827-
Identity<TString> id = tf.withControlDependencies(Arrays.asList(saveFilename,saveVariables))
828-
.withName("control_dependency").identity(saveFilename);
829-
Restore restoreVariables = tf.train.restore(
830-
saveFilename,
831-
varNamesTensor,
832-
varSlices,
833-
varTypes
834-
);
835-
List<Op> restoreOps = new ArrayList<>(varOutputs.size());
836-
for (int i = 0; i < varOutputs.size(); ++i) {
837-
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
816+
Placeholder<TString> filename = tf.withName("filename").placeholder(TString.class);
817+
Identity<TString> save = null;
818+
NoOp restore = null;
819+
820+
if (varNames.isEmpty()) {
821+
save = tf.withName("empty_save").identity(filename);
822+
restore = tf.withName("restore_all").noOp();
823+
} else {
824+
String[] tmp = new String[varNames.size()];
825+
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
826+
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
827+
Save saveVars = tf.train.save(filename, varNamesTensor, varSlices, varOutputs);
828+
List<Op> saveDeps = Arrays.asList(filename, saveVars);
829+
Restore restoreVars = tf.train.restore(filename, varNamesTensor, varSlices, varTypes);
830+
List<Op> restoreDeps = new ArrayList<>(varOutputs.size());
831+
for (int i = 0; i < varOutputs.size(); ++i) {
832+
restoreDeps.add(tf.assign(varOutputs.get(i), (Operand) restoreVars.tensors().get(i)));
833+
}
834+
save = tf.withControlDependencies(saveDeps).withName("control_dependency").identity(filename);
835+
restore = tf.withControlDependencies(restoreDeps).withName("restore_all").noOp();
838836
}
839-
NoOp restoreAll = tf.withControlDependencies(restoreOps).withName("restore_all").noOp();
840837

838+
// 'Filename' must be the name of a tensor (i.e. with output index)
839+
// 'Save' must be an operation name, even if the field name is confusing (see SaverDef doc)
840+
// 'Restore' must be an operation name
841841
return SaverDef.newBuilder()
842-
.setFilenameTensorName(saveFilename.op().name())
843-
.setSaveTensorName(id.op().name())
844-
.setRestoreOpName(restoreAll.op().name())
842+
.setFilenameTensorName(filename.output().name())
843+
.setSaveTensorName(save.op().name())
844+
.setRestoreOpName(restore.op().name())
845845
.build();
846846
}
847847

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ public int index() {
3838
return index;
3939
}
4040

41+
/** Returns the full name of this Output (a.k.a. tensor name) */
42+
public String name() {
43+
return op().name() + ":" + index;
44+
}
45+
4146
/** Returns the DataType of the tensor referred to by this Output. */
4247
@SuppressWarnings("unchecked")
4348
public DataType dataType() {
@@ -48,7 +53,7 @@ public DataType dataType() {
4853
@SuppressWarnings("unchecked")
4954
@Override
5055
public Class<T> type() {
51-
return (Class<T>)TensorTypeRegistry.find(dataType()).type();
56+
return (Class<T>) TensorTypeRegistry.find(dataType()).type();
5257
}
5358

5459
/**
@@ -63,7 +68,10 @@ public Class<T> type() {
6368
public <U extends TType> Output<U> expect(Class<U> type) {
6469
if (type != type()) {
6570
throw new IllegalArgumentException(
66-
"Cannot cast from output of " + this.type().getSimpleName() + " to output of " + type.getSimpleName());
71+
"Cannot cast from output of "
72+
+ this.type().getSimpleName()
73+
+ " to output of "
74+
+ type.getSimpleName());
6775
}
6876
return ((Output<U>) this);
6977
}
@@ -80,17 +88,16 @@ public <U extends TType> Output<U> expect(Class<U> type) {
8088
*
8189
* @return tensor
8290
* @throws IllegalStateException if this output results from a graph
83-
* @throws ClassCastException if the type of the tensor and this output are unexpectedly incompatible
91+
* @throws ClassCastException if the type of the tensor and this output are unexpectedly
92+
* incompatible
8493
* @see EagerSession
8594
*/
8695
@SuppressWarnings("unchecked")
8796
public T asTensor() {
88-
return (T)operation.tensor(index);
97+
return (T) operation.tensor(index);
8998
}
9099

91-
/**
92-
* Returns the (possibly partially known) shape of the tensor referred to by this output.
93-
*/
100+
/** Returns the (possibly partially known) shape of the tensor referred to by this output. */
94101
@Override
95102
public Shape shape() {
96103
return operation.shape(index);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
/*
2-
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3-
*
4-
* Licensed under the Apache License, Version 2.0 (the "License");
5-
* you may not use this file except in compliance with the License.
6-
* You may obtain a copy of the License at
7-
*
8-
* http://www.apache.org/licenses/LICENSE-2.0
9-
*
10-
* Unless required by applicable law or agreed to in writing, software
11-
* distributed under the License is distributed on an "AS IS" BASIS,
12-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
* See the License for the specific language governing permissions and
14-
* limitations under the License.
15-
*/
1+
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
1616
package org.tensorflow;
1717

1818
import java.util.HashMap;
@@ -171,7 +171,7 @@ public Set<String> outputNames() {
171171

172172
@Override
173173
public String toString() {
174-
StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n");
174+
StringBuilder strBuilder = new StringBuilder("Signature for \"" + key + "\":\n");
175175
String methodName = methodName();
176176
if (methodName != null && !methodName.isEmpty()) {
177177
strBuilder.append("\tMethod: \"").append(methodName).append("\"\n");

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.tensorflow.proto.framework.RunOptions;
4545
import org.tensorflow.proto.framework.SignatureDef;
4646
import org.tensorflow.proto.framework.TensorInfo;
47+
import org.tensorflow.proto.util.SaverDef;
4748
import org.tensorflow.types.TFloat32;
4849

4950
/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */
@@ -123,7 +124,13 @@ public void exportFunctionWithVariables() throws IOException {
123124
try (SavedModelBundle savedModel =
124125
SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) {
125126
assertNotNull(savedModel.metaGraphDef());
126-
assertNotNull(savedModel.metaGraphDef().getSaverDef());
127+
128+
SaverDef saverDef = savedModel.metaGraphDef().getSaverDef();
129+
assertNotNull(saverDef);
130+
assertEquals("save/filename:0", saverDef.getFilenameTensorName());
131+
assertEquals("save/control_dependency", saverDef.getSaveTensorName());
132+
assertEquals("save/restore_all", saverDef.getRestoreOpName());
133+
127134
assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
128135
assertEquals(Signature.DEFAULT_KEY,
129136
savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next());

0 commit comments

Comments
 (0)