Skip to content

Commit 3bc0da7

Browse files
SiqiaoWu1993tensorflower-gardener
authored andcommitted
Let RunWithSortedInputsOutputs take graph_name.
PiperOrigin-RevId: 815900754
1 parent 7354cce commit 3bc0da7

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

tensorflow/core/tfrt/graph_executor/graph_executor.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ void CreateSortedNamesAndOriginalIndices(absl::Span<const std::string> names,
560560
}
561561

562562
absl::Status GraphExecutor::RunWithSortedInputsOutputs(
563-
const RunOptions& run_options,
563+
const RunOptions& run_options, absl::string_view graph_name,
564564
absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
565565
absl::Span<const std::string> sorted_input_names,
566566
absl::Span<const tensorflow::DataType> sorted_input_dtypes,
@@ -570,12 +570,11 @@ absl::Status GraphExecutor::RunWithSortedInputsOutputs(
570570
absl::Span<const int> output_original_indices,
571571
std::vector<tensorflow::Tensor>* outputs) {
572572
// Load the client graph.
573-
TF_ASSIGN_OR_RETURN(
574-
LoadedClientGraph & loaded_client_graph,
575-
GetOrCreateLoadedClientGraph(
576-
run_options, sorted_input_names, sorted_input_dtypes,
577-
sorted_output_names, sorted_target_node_names, run_options.work_queue,
578-
/*graph_name=*/{}, inputs));
573+
TF_ASSIGN_OR_RETURN(LoadedClientGraph & loaded_client_graph,
574+
GetOrCreateLoadedClientGraph(
575+
run_options, sorted_input_names, sorted_input_dtypes,
576+
sorted_output_names, sorted_target_node_names,
577+
run_options.work_queue, graph_name, inputs));
579578

580579
// Get a shared_ptr of the executable so that during the current request the
581580
// executable to use is guaranteed to be alive.
@@ -683,9 +682,9 @@ absl::Status GraphExecutor::Run(
683682
std::sort(sorted_target_node_names.begin(), sorted_target_node_names.end());
684683

685684
return RunWithSortedInputsOutputs(
686-
run_options, inputs, sorted_input_names, sorted_input_dtypes,
687-
sorted_output_names, sorted_target_node_names, input_original_indices,
688-
output_original_indices, outputs);
685+
run_options, /*graph_name=*/"", inputs, sorted_input_names,
686+
sorted_input_dtypes, sorted_output_names, sorted_target_node_names,
687+
input_original_indices, output_original_indices, outputs);
689688
}
690689

691690
absl::Status GraphExecutor::Extend(const GraphDef& graph) {

tensorflow/core/tfrt/graph_executor/graph_executor.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,17 @@ class GraphExecutor {
274274
std::vector<tensorflow::Tensor>* outputs);
275275

276276
// Similar as `Run`, but it requires additional input parameters to specify
277-
// the sorted input/output names and the original indices of the
278-
// inputs/outputs. The caller must guarantee that inputs are in the same order
279-
// as of `sorted_input_names`. The sorted input/output names are needed to
280-
// consistently build the key for looking up the `LoadedClientGraph` in the
277+
// the `graph_name`, the sorted input/output names and the original indices of
278+
// the inputs/outputs. The caller must guarantee that inputs are in the same
279+
// order as of `sorted_input_names`. The sorted input/output names are needed
280+
// to consistently build the key for looking up the `LoadedClientGraph` in the
281281
// cache. The original indices are needed to map the results to the original
282-
// inputs/outputs.
282+
// inputs/outputs. The `graph_name` will be used to lookup the compiled graph
283+
// in the cache. It is usually the signature name of the graph. If it is
284+
// empty, a joined name will be constructed from the sorted input/output names
285+
// to lookup the `LoadedClientGraph` in the cache.
283286
absl::Status RunWithSortedInputsOutputs(
284-
const RunOptions& run_options,
287+
const RunOptions& run_options, absl::string_view graph_name,
285288
absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
286289
absl::Span<const std::string> sorted_input_names,
287290
absl::Span<const tensorflow::DataType> sorted_input_dtypes,

0 commit comments

Comments
 (0)