11"""Converters for OSS Vizier's protos from/to PyVizier's classes."""
2- import datetime
32import logging
3+ from datetime import timezone
44from typing import List , Optional , Sequence , Tuple , Union
55
66from google .protobuf import duration_pb2
7+ from google .protobuf import struct_pb2
8+ from google .protobuf import timestamp_pb2
79from google .cloud .aiplatform .compat .types import study as study_pb2
810from google .cloud .aiplatform .vizier .pyvizier import ScaleType
911from google .cloud .aiplatform .vizier .pyvizier import ParameterType
@@ -80,8 +82,8 @@ def _set_default_value(
8082 default_value : Union [float , int , str ],
8183 ):
8284 """Sets the protos' default_value field."""
83- which_pv_spec = proto .WhichOneof ("parameter_value_spec" )
84- getattr (proto , which_pv_spec ).default_value . value = default_value
85+ which_pv_spec = proto ._pb . WhichOneof ("parameter_value_spec" )
86+ getattr (proto , which_pv_spec ).default_value = default_value
8587
8688 @classmethod
8789 def _matching_parent_values (
@@ -280,17 +282,16 @@ def to_proto(
280282 cls , parameter_value : ParameterValue , name : str
281283 ) -> study_pb2 .Trial .Parameter :
282284 """Returns Parameter Proto."""
283- proto = study_pb2 .Trial .Parameter (parameter_id = name )
284-
285285 if isinstance (parameter_value .value , int ):
286- proto . value . number_value = parameter_value .value
286+ value = struct_pb2 . Value ( number_value = parameter_value .value )
287287 elif isinstance (parameter_value .value , bool ):
288- proto . value . bool_value = parameter_value .value
288+ value = struct_pb2 . Value ( bool_value = parameter_value .value )
289289 elif isinstance (parameter_value .value , float ):
290- proto . value . number_value = parameter_value .value
290+ value = struct_pb2 . Value ( number_value = parameter_value .value )
291291 elif isinstance (parameter_value .value , str ):
292- proto . value . string_value = parameter_value .value
292+ value = struct_pb2 . Value ( string_value = parameter_value .value )
293293
294+ proto = study_pb2 .Trial .Parameter (parameter_id = name , value = value )
294295 return proto
295296
296297
@@ -340,18 +341,19 @@ def from_proto(cls, proto: study_pb2.Measurement) -> Measurement:
340341 @classmethod
341342 def to_proto (cls , measurement : Measurement ) -> study_pb2 .Measurement :
342343 """Converts to Measurement proto."""
343- proto = study_pb2 .Measurement ()
344+ int_seconds = int (measurement .elapsed_secs )
345+ proto = study_pb2 .Measurement (
346+ elapsed_duration = duration_pb2 .Duration (
347+ seconds = int_seconds ,
348+ nanos = int (1e9 * (measurement .elapsed_secs - int_seconds )),
349+ )
350+ )
344351 for name , metric in measurement .metrics .items ():
345352 proto .metrics .append (
346353 study_pb2 .Measurement .Metric (metric_id = name , value = metric .value )
347354 )
348355
349356 proto .step_count = measurement .steps
350- int_seconds = int (measurement .elapsed_secs )
351- proto .elapsed_duration = duration_pb2 .Duration (
352- seconds = int_seconds ,
353- nanos = int (1e9 * (measurement .elapsed_secs - int_seconds )),
354- )
355357 return proto
356358
357359
@@ -426,8 +428,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
426428 infeasibility_reason = None
427429 if proto .state == study_pb2 .Trial .State .SUCCEEDED :
428430 if proto .end_time :
429- completion_ts = proto .end_time .nanosecond / 1e9
430- completion_time = datetime .datetime .fromtimestamp (completion_ts )
431+ completion_time = (
432+ proto .end_time .timestamp_pb ()
433+ .ToDatetime ()
434+ .replace (tzinfo = timezone .utc )
435+ )
431436 elif proto .state == study_pb2 .Trial .State .INFEASIBLE :
432437 infeasibility_reason = proto .infeasible_reason
433438
@@ -437,8 +442,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
437442
438443 creation_time = None
439444 if proto .start_time :
440- creation_ts = proto .start_time .nanosecond / 1e9
441- creation_time = datetime .datetime .fromtimestamp (creation_ts )
445+ creation_time = (
446+ proto .start_time .timestamp_pb ()
447+ .ToDatetime ()
448+ .replace (tzinfo = timezone .utc )
449+ )
442450 return Trial (
443451 id = int (proto .name .split ("/" )[- 1 ]),
444452 description = proto .name ,
@@ -481,22 +489,26 @@ def to_proto(cls, pytrial: Trial) -> study_pb2.Trial:
481489
482490 # pytrial always adds an empty metric. Ideally, we should remove it if the
483491 # metric does not exist in the study config.
492+ # setattr() is required here as `proto.final_measurement.CopyFrom`
493+ # raises AttributeErrors when setting the field on the pb2 compat types.
484494 if pytrial .final_measurement is not None :
485- proto .final_measurement .CopyFrom (
486- MeasurementConverter .to_proto (pytrial .final_measurement )
495+ setattr (
496+ proto ,
497+ "final_measurement" ,
498+ MeasurementConverter .to_proto (pytrial .final_measurement ),
487499 )
488500
489501 for measurement in pytrial .measurements :
490502 proto .measurements .append (MeasurementConverter .to_proto (measurement ))
491503
492504 if pytrial .creation_time is not None :
493- creation_secs = datetime . datetime . timestamp ( pytrial . creation_time )
494- proto . start_time .seconds = int ( creation_secs )
495- proto . start_time . nanos = int ( 1e9 * ( creation_secs - int ( creation_secs )) )
505+ start_time = timestamp_pb2 . Timestamp ( )
506+ start_time .FromDatetime ( pytrial . creation_time )
507+ setattr ( proto , " start_time" , start_time )
496508 if pytrial .completion_time is not None :
497- completion_secs = datetime . datetime . timestamp ( pytrial . completion_time )
498- proto . end_time .seconds = int ( completion_secs )
499- proto . end_time . nanos = int ( 1e9 * ( completion_secs - int ( completion_secs )) )
509+ end_time = timestamp_pb2 . Timestamp ( )
510+ end_time .FromDatetime ( pytrial . completion_time )
511+ setattr ( proto , " end_time" , end_time )
500512 if pytrial .infeasibility_reason is not None :
501513 proto .infeasible_reason = pytrial .infeasibility_reason
502514 return proto
0 commit comments