Skip to content

Commit 6c2bc1f

Browse files
authored
Add live_info_terminal, closes #436 (#441)
* Add live_info_terminal, closes #436 * Add time * rename var * Add doc-string
1 parent bdcbbf3 commit 6c2bc1f

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

adaptive/runner.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,55 @@ def live_info(self, *, update_interval: float = 0.1) -> None:
776776
"""
777777
return live_info(self, update_interval=update_interval)
778778

779+
def live_info_terminal(
780+
self, *, update_interval: float = 0.5, overwrite_previous: bool = True
781+
) -> asyncio.Task:
782+
"""
783+
Display live information about the runner in the terminal.
784+
785+
This function provides a live update of the runner's status in the terminal.
786+
The update can either overwrite the previous status or be printed on a new line.
787+
788+
Parameters
789+
----------
790+
update_interval : float, optional
791+
The time interval (in seconds) at which the runner's status is updated in the terminal.
792+
Default is 0.5 seconds.
793+
overwrite_previous : bool, optional
794+
If True, each update will overwrite the previous status in the terminal.
795+
If False, each update will be printed on a new line.
796+
Default is True.
797+
798+
Returns
799+
-------
800+
asyncio.Task
801+
The asynchronous task responsible for updating the runner's status in the terminal.
802+
803+
Examples
804+
--------
805+
>>> runner = AsyncRunner(...)
806+
>>> runner.live_info_terminal(update_interval=1.0, overwrite_previous=False)
807+
808+
Notes
809+
-----
810+
This function uses ANSI escape sequences to control the terminal's cursor position.
811+
It might not work as expected on all terminal emulators.
812+
"""
813+
814+
async def _update(runner: AsyncRunner) -> None:
815+
try:
816+
while not runner.task.done():
817+
if overwrite_previous:
818+
# Clear the terminal
819+
print("\033[H\033[J", end="")
820+
print(_info_text(runner, separator="\t"))
821+
await asyncio.sleep(update_interval)
822+
823+
except asyncio.CancelledError:
824+
print("Live info display cancelled.")
825+
826+
return self.ioloop.create_task(_update(self))
827+
779828
async def _run(self) -> None:
780829
first_completed = asyncio.FIRST_COMPLETED
781830

@@ -855,6 +904,43 @@ async def _saver():
855904
return self.saving_task
856905

857906

907+
def _info_text(runner, separator: str = "\n"):
908+
status = runner.status()
909+
910+
color_map = {
911+
"cancelled": "\033[33m", # Yellow
912+
"failed": "\033[31m", # Red
913+
"running": "\033[34m", # Blue
914+
"finished": "\033[32m", # Green
915+
}
916+
917+
overhead = runner.overhead()
918+
if overhead < 50:
919+
overhead_color = "\033[32m" # Green
920+
else:
921+
overhead_color = "\033[31m" # Red
922+
923+
info = [
924+
("time", str(datetime.now())),
925+
("status", f"{color_map[status]}{status}\033[0m"),
926+
("elapsed time", str(timedelta(seconds=runner.elapsed_time()))),
927+
("overhead", f"{overhead_color}{overhead:.2f}%\033[0m"),
928+
]
929+
930+
with suppress(Exception):
931+
info.append(("# of points", runner.learner.npoints))
932+
933+
with suppress(Exception):
934+
info.append(("# of samples", runner.learner.nsamples))
935+
936+
with suppress(Exception):
937+
info.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}'))
938+
939+
width = 30
940+
formatted_info = [f"{k}: {v}".ljust(width) for i, (k, v) in enumerate(info)]
941+
return separator.join(formatted_info)
942+
943+
858944
# Default runner
859945
Runner = AsyncRunner
860946

0 commit comments

Comments
 (0)