@@ -437,6 +437,23 @@ def get(
437437 except exceptions .NotFound :
438438 return None
439439
440+ def _initialize_experiment_run (
441+ self ,
442+ node : Union [context .Context , execution .Execution ],
443+ experiment : Optional [experiment_resources .Experiment ] = None ,
444+ ):
445+ self ._experiment = experiment
446+ self ._run_name = node .display_name
447+ self ._metadata_node = node
448+ self ._largest_step = None
449+
450+ if self ._is_legacy_experiment_run ():
451+ self ._metadata_metric_artifact = self ._v1_get_metric_artifact ()
452+ self ._backing_tensorboard_run = None
453+ else :
454+ self ._metadata_metric_artifact = None
455+ self ._backing_tensorboard_run = self ._lookup_tensorboard_run_artifact ()
456+
440457 @classmethod
441458 def list (
442459 cls ,
@@ -495,33 +512,17 @@ def list(
495512
496513 run_executions = execution .Execution .list (filter = filter_str , ** metadata_args )
497514
498- def _initialize_experiment_run (context : context .Context ) -> ExperimentRun :
515+ def _create_experiment_run (context : context .Context ) -> ExperimentRun :
499516 this_experiment_run = cls .__new__ (cls )
500- this_experiment_run ._experiment = experiment
501- this_experiment_run ._run_name = context .display_name
502- this_experiment_run ._metadata_node = context
503-
504- with experiment_resources ._SetLoggerLevel (resource ):
505- tb_run = this_experiment_run ._lookup_tensorboard_run_artifact ()
506- if tb_run :
507- this_experiment_run ._backing_tensorboard_run = tb_run
508- else :
509- this_experiment_run ._backing_tensorboard_run = None
510-
511- this_experiment_run ._largest_step = None
517+ this_experiment_run ._initialize_experiment_run (context , experiment )
512518
513519 return this_experiment_run
514520
515- def _initialize_v1_experiment_run (
521+ def _create_v1_experiment_run (
516522 execution : execution .Execution ,
517523 ) -> ExperimentRun :
518524 this_experiment_run = cls .__new__ (cls )
519- this_experiment_run ._experiment = experiment
520- this_experiment_run ._run_name = execution .display_name
521- this_experiment_run ._metadata_node = execution
522- this_experiment_run ._metadata_metric_artifact = (
523- this_experiment_run ._v1_get_metric_artifact ()
524- )
525+ this_experiment_run ._initialize_experiment_run (execution , experiment )
525526
526527 return this_experiment_run
527528
@@ -530,13 +531,13 @@ def _initialize_v1_experiment_run(
530531 max_workers = max ([len (run_contexts ), len (run_executions )])
531532 ) as executor :
532533 submissions = [
533- executor .submit (_initialize_experiment_run , context )
534+ executor .submit (_create_experiment_run , context )
534535 for context in run_contexts
535536 ]
536537 experiment_runs = [submission .result () for submission in submissions ]
537538
538539 submissions = [
539- executor .submit (_initialize_v1_experiment_run , execution )
540+ executor .submit (_create_v1_experiment_run , execution )
540541 for execution in run_executions
541542 ]
542543
@@ -560,30 +561,20 @@ def _query_experiment_row(
560561 Experiment run row that represents this run.
561562 """
562563 this_experiment_run = cls .__new__ (cls )
563- this_experiment_run ._metadata_node = node
564+ this_experiment_run ._initialize_experiment_run ( node )
564565
565566 row = experiment_resources ._ExperimentRow (
566567 experiment_run_type = node .schema_title ,
567568 name = node .display_name ,
568569 )
569570
570- if isinstance (node , context .Context ):
571- this_experiment_run ._backing_tensorboard_run = (
572- this_experiment_run ._lookup_tensorboard_run_artifact ()
573- )
574- row .params = node .metadata [constants ._PARAM_KEY ]
575- row .metrics = node .metadata [constants ._METRIC_KEY ]
576- row .time_series_metrics = (
577- this_experiment_run ._get_latest_time_series_metric_columns ()
578- )
579- row .state = node .metadata [constants ._STATE_KEY ]
580- else :
581- this_experiment_run ._metadata_metric_artifact = (
582- this_experiment_run ._v1_get_metric_artifact ()
583- )
584- row .params = node .metadata
585- row .metrics = this_experiment_run ._metadata_metric_artifact .metadata
586- row .state = node .state .name
571+ row .params = this_experiment_run .get_params ()
572+ row .metrics = this_experiment_run .get_metrics ()
573+ row .state = this_experiment_run .get_state ()
574+ row .time_series_metrics = (
575+ this_experiment_run ._get_latest_time_series_metric_columns ()
576+ )
577+
587578 return row
588579
589580 def _get_logged_pipeline_runs (self ) -> List [context .Context ]:
@@ -659,7 +650,7 @@ def log(
659650
660651 @staticmethod
661652 def _validate_run_id (run_id : str ):
662- """Validates the run id
653+ """Validates the run id.
663654
664655 Args:
665656 run_id(str): Required. The run id to validate.
@@ -1455,6 +1446,13 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
14551446 else :
14561447 return self ._metadata_node .metadata [constants ._METRIC_KEY ]
14571448
1449+ def get_state (self ) -> gca_execution .Execution .State :
1450+ """The state of this run."""
1451+ if self ._is_legacy_experiment_run ():
1452+ return self ._metadata_node .state .name
1453+ else :
1454+ return self ._metadata_node .metadata [constants ._STATE_KEY ]
1455+
14581456 @_v1_not_supported
14591457 def get_classification_metrics (self ) -> List [Dict [str , Union [str , List ]]]:
14601458 """Get all the classification metrics logged to this run.
0 commit comments