
Commit 1b0ecc2

Merge branch 'master' into fix-deprecated-statement
2 parents: 2ae8f3a + 6f8ad2a

File tree

18 files changed: +119, -64 lines


.github/workflows/gpu-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ jobs:
         uses: nick-fields/[email protected]
         with:
           max_attempts: 5
-          timeout_minutes: 25
+          timeout_minutes: 45
           shell: bash
           command: docker exec -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2'
           new_command_on_retry: docker exec -e USE_LAST_FAILED=1 -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2'

ignite/contrib/engines/common.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def setup_common_training_handlers(
         lr_scheduler: learning rate scheduler
             as native torch LRScheduler or ignite's parameter scheduler.
         with_gpu_stats: if True, :class:`~ignite.metrics.GpuInfo` is attached to the
-            trainer. This requires `pynvml` package to be installed.
+            trainer. This requires `pynvml<12` package to be installed.
         output_names: list of names associated with `update_function` output dictionary.
         with_pbars: if True, two progress bars on epochs and optionally on iterations are attached.
             Default, True.
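
For context, a minimal usage sketch of the documented option, assuming an existing ignite Engine named `trainer` (the trainer and output names are illustrative, not part of this diff):

    # Sketch: enable GPU stats when setting up common training handlers.
    # Requires the pinned dependency from the docstring above: pip install 'pynvml<12'
    from ignite.contrib.engines import common

    common.setup_common_training_handlers(
        trainer,                 # assumed: an ignite Engine wrapping the training step
        with_gpu_stats=True,     # attaches ignite.metrics.GpuInfo to the trainer
        output_names=["loss"],   # names matching the update_function output
        with_pbars=True,
    )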

ignite/engine/engine.py

Lines changed: 28 additions & 13 deletions
@@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
         self._process_function = process_function
         self.last_event_name: Optional[Events] = None
         self.should_terminate = False
+        self.skip_completed_after_termination = False
         self.should_terminate_single_epoch = False
         self.should_interrupt = False
         self.state = State()
@@ -538,7 +539,7 @@ def call_interrupt():
         self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.")
         self.should_interrupt = True
 
-    def terminate(self) -> None:
+    def terminate(self, skip_completed: bool = False) -> None:
         """Sends terminate signal to the engine, so that it terminates completely the run. The run is
         terminated after the event on which ``terminate`` method was called. The following events are triggered:
@@ -547,6 +548,9 @@ def terminate(self) -> None:
         - :attr:`~ignite.engine.events.Events.TERMINATE`
         - :attr:`~ignite.engine.events.Events.COMPLETED`
 
+        Args:
+            skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
+                :attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.
 
         Examples:
             .. testcode::
@@ -617,9 +621,12 @@ def terminate():
         .. versionchanged:: 0.4.10
             Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669
 
+        .. versionchanged:: 0.5.2
+            Added `skip_completed` flag
         """
         self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
         self.should_terminate = True
+        self.skip_completed_after_termination = skip_completed
 
     def terminate_epoch(self) -> None:
         """Sends terminate signal to the engine, so that it terminates the current epoch. The run
@@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
             time_taken = time.time() - start_time
             # time is available for handlers but must be updated after fire
             self.state.times[Events.COMPLETED.name] = time_taken
-            handlers_start_time = time.time()
-            self._fire_event(Events.COMPLETED)
-            time_taken += time.time() - handlers_start_time
-            # update time wrt handlers
-            self.state.times[Events.COMPLETED.name] = time_taken
+
+            # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
+            if not (self.should_terminate and self.skip_completed_after_termination):
+                handlers_start_time = time.time()
+                self._fire_event(Events.COMPLETED)
+                time_taken += time.time() - handlers_start_time
+                # update time wrt handlers
+                self.state.times[Events.COMPLETED.name] = time_taken
+
             hours, mins, secs = _to_hours_mins_secs(time_taken)
-            self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
+            self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
 
         except BaseException as e:
             self._dataloader_iter = None
@@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State:
             time_taken = time.time() - start_time
             # time is available for handlers but must be updated after fire
             self.state.times[Events.COMPLETED.name] = time_taken
-            handlers_start_time = time.time()
-            self._fire_event(Events.COMPLETED)
-            time_taken += time.time() - handlers_start_time
-            # update time wrt handlers
-            self.state.times[Events.COMPLETED.name] = time_taken
+
+            # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
+            if not (self.should_terminate and self.skip_completed_after_termination):
+                handlers_start_time = time.time()
+                self._fire_event(Events.COMPLETED)
+                time_taken += time.time() - handlers_start_time
+                # update time wrt handlers
+                self.state.times[Events.COMPLETED.name] = time_taken
+
             hours, mins, secs = _to_hours_mins_secs(time_taken)
-            self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
+            self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
 
         except BaseException as e:
             self._dataloader_iter = None
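
A short usage sketch of the new flag, following the docstring above; the toy process function and the iteration threshold are illustrative only:

    from ignite.engine import Engine, Events

    trainer = Engine(lambda engine, batch: None)  # toy process function

    @trainer.on(Events.ITERATION_COMPLETED)
    def stop_early(engine):
        if engine.state.iteration >= 10:
            # Terminate the run; with skip_completed=True the engine still fires
            # Events.TERMINATE but no longer fires Events.COMPLETED afterwards.
            engine.terminate(skip_completed=True)

    @trainer.on(Events.COMPLETED)
    def on_completed(engine):
        print("COMPLETED fired")  # not reached when skip_completed=True

    trainer.run(range(100), max_epochs=1)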

ignite/engine/events.py

Lines changed: 18 additions & 7 deletions
@@ -259,36 +259,47 @@ class Events(EventEnum):
     - TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
       after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
       :meth:`~ignite.engine.engine.Engine.terminate()` call.
+    - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
+      when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
 
     - TERMINATE : triggered when the run is about to end completely,
       after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
 
-    - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
-      when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
-    - COMPLETED : triggered when engine's run is completed
+    - COMPLETED : triggered when engine's run is completed or terminated with
+      :meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag
+      `skip_completed` is set to True.
 
     The table below illustrates which events are triggered when various termination methods are called.
 
     .. list-table::
-       :widths: 24 25 33 18
+       :widths: 35 38 28 20 20
        :header-rows: 1
 
        * - Method
-         - EVENT_COMPLETED
         - TERMINATE_SINGLE_EPOCH
+         - EPOCH_COMPLETED
         - TERMINATE
+         - COMPLETED
       * - no termination
-         - ✔
         - ✗
+         - ✔
         - ✗
+         - ✔
       * - :meth:`~ignite.engine.engine.Engine.terminate_epoch()`
         - ✔
         - ✔
         - ✗
+         - ✔
       * - :meth:`~ignite.engine.engine.Engine.terminate()`
         - ✗
         - ✔
         - ✔
+         - ✔
+      * - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True`
+         - ✗
+         - ✔
+         - ✔
+         - ✗
 
     Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine:
@@ -357,7 +368,7 @@ class CustomEvents(EventEnum):
     STARTED = "started"
     """triggered when engine's run is started."""
     COMPLETED = "completed"
-    """triggered when engine's run is completed"""
+    """triggered when engine's run is completed, or after receiving terminate() call."""
 
     ITERATION_STARTED = "iteration_started"
     """triggered when an iteration is started."""

ignite/metrics/clustering/calinski_harabasz_score.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float:
     from sklearn.metrics import calinski_harabasz_score
 
-    np_features = features.numpy()
-    np_labels = labels.numpy()
+    np_features = features.cpu().numpy()
+    np_labels = labels.cpu().numpy()
     score = calinski_harabasz_score(np_features, np_labels)
     return score
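
The reason for the repeated `.cpu()` additions in this and the following metric files: `Tensor.numpy()` only works on CPU tensors, so values gathered on a GPU must be copied to host memory first. A small illustration, guarded so it also runs on CPU-only machines:

    import torch

    t = torch.randn(4, 2)
    t.numpy()  # fine: the tensor already lives in host memory

    if torch.cuda.is_available():
        t_gpu = t.to("cuda")
        # t_gpu.numpy() raises TypeError: can't convert cuda:0 device type tensor to numpy.
        # Use Tensor.cpu() to copy the tensor to host memory first.
        np_t = t_gpu.cpu().numpy()  # explicit device transfer, as in the fix above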

ignite/metrics/clustering/davies_bouldin_score.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float:
     from sklearn.metrics import davies_bouldin_score
 
-    np_features = features.numpy()
-    np_labels = labels.numpy()
+    np_features = features.cpu().numpy()
+    np_labels = labels.cpu().numpy()
     score = davies_bouldin_score(np_features, np_labels)
     return score

ignite/metrics/clustering/silhouette_score.py

Lines changed: 2 additions & 2 deletions
@@ -111,7 +111,7 @@ def __init__(
     def _silhouette_score(self, features: Tensor, labels: Tensor) -> float:
         from sklearn.metrics import silhouette_score
 
-        np_features = features.numpy()
-        np_labels = labels.numpy()
+        np_features = features.cpu().numpy()
+        np_labels = labels.cpu().numpy()
         score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs)
         return score

ignite/metrics/gpu_info.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 class GpuInfo(Metric):
     """Provides GPU information: a) used memory percentage, b) gpu utilization percentage values as Metric
-    on each iterations.
+    on each iterations. This metric requires `pynvml <https://pypi.org/project/pynvml/>`_ package of version `<12`.
 
     .. Note ::
 
@@ -39,7 +39,7 @@ def __init__(self) -> None:
         except ImportError:
             raise ModuleNotFoundError(
                 "This contrib module requires pynvml to be installed. "
-                "Please install it with command: \n pip install pynvml"
+                "Please install it with command: \n pip install 'pynvml<12'"
             )
         # Let's check available devices
         if not torch.cuda.is_available():
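
A minimal sketch of attaching the metric, assuming a CUDA machine and an existing ignite Engine named `trainer` (the name is illustrative):

    # Requires: pip install 'pynvml<12'
    from ignite.metrics import GpuInfo

    GpuInfo().attach(trainer, name="gpu")
    # After each iteration the values are published into the engine state,
    # e.g. trainer.state.metrics["gpu:0 mem(%)"] and trainer.state.metrics["gpu:0 util(%)"]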

ignite/metrics/regression/kendall_correlation.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]:
         raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.")
 
     def _tau(predictions: Tensor, targets: Tensor) -> float:
-        np_preds = predictions.flatten().numpy()
-        np_targets = targets.flatten().numpy()
+        np_preds = predictions.flatten().cpu().numpy()
+        np_targets = targets.flatten().cpu().numpy()
         r = kendalltau(np_preds, np_targets, variant=variant).statistic
         return r

ignite/metrics/regression/spearman_correlation.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@
 def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
     from scipy.stats import spearmanr
 
-    np_preds = predictions.flatten().numpy()
-    np_targets = targets.flatten().numpy()
+    np_preds = predictions.flatten().cpu().numpy()
+    np_targets = targets.flatten().cpu().numpy()
     r = spearmanr(np_preds, np_targets).statistic
     return r
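
The same flatten -> cpu -> numpy path, shown outside the helper with toy tensors (illustrative only):

    import torch
    from scipy.stats import spearmanr

    preds = torch.tensor([0.1, 0.9, 0.4, 0.25])
    targets = torch.tensor([0.0, 1.0, 0.5, 0.3])

    # Same conversion path as the fixed helper above.
    r = spearmanr(preds.flatten().cpu().numpy(), targets.flatten().cpu().numpy()).statistic
    print(r)  # 1.0 here: the two toy series have identical rank orderings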
