Skip to content

Commit f400030

Browse files
authored
Removed sync arg from wandb.log (#3472)
Fixes #3471 Description: Remove the `sync` argument from the `wandb.log()` calls in the OutputHandler and OptimizerParamsHandler, since the argument was removed. Kept the argument in the initialisation of the handlers, so that existing code that may include the argument keeps working. Check list: - [ ] New tests are added (if a new feature is added) - [ ] New doc strings: description and/or example code are in RST format - [x] Documentation is updated (if required)
1 parent 4f6b37d commit f400030

File tree

2 files changed

+22
-45
lines changed

2 files changed

+22
-45
lines changed

ignite/handlers/wandb_logger.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""WandB logger and its helper handlers."""
22

33
from typing import Any, Callable, List, Optional, Union
4+
from warnings import warn
45

56
from torch.optim import Optimizer
67

@@ -172,8 +173,7 @@ class OutputHandler(BaseOutputHandler):
172173
Default is None, global_step based on attached engine. If provided,
173174
uses function output as global_step. To setup global step from another engine, please use
174175
:meth:`~ignite.handlers.wandb_logger.global_step_from_engine`.
175-
sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
176-
the default value of wandb.log.
176+
sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.
177177
178178
Examples:
179179
.. code-block:: python
@@ -284,7 +284,8 @@ def __init__(
284284
state_attributes: Optional[List[str]] = None,
285285
):
286286
super().__init__(tag, metric_names, output_transform, global_step_transform, state_attributes)
287-
self.sync = sync
287+
if sync is not None:
288+
warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")
288289

289290
def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
290291
if not isinstance(logger, WandBLogger):
@@ -298,7 +299,7 @@ def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, E
298299
)
299300

300301
metrics = self._setup_output_metrics_state_attrs(engine, log_text=True, key_tuple=False)
301-
logger.log(metrics, step=global_step, sync=self.sync)
302+
logger.log(metrics, step=global_step)
302303

303304

304305
class OptimizerParamsHandler(BaseOptimizerParamsHandler):
@@ -309,8 +310,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
309310
as a sequence.
310311
param_name: parameter name
311312
tag: common title for all produced plots. For example, "generator"
312-
sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
313-
the default value of wandb.log.
313+
sync: Deprecated, has no function. Argument is kept here for compatibility with existing code.
314314
315315
Examples:
316316
.. code-block:: python
@@ -346,7 +346,8 @@ def __init__(
346346
self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None
347347
):
348348
super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
349-
self.sync = sync
349+
if sync is not None:
350+
warn("The sync argument for the WandBLoggers is no longer used, and may be removed in the future")
350351

351352
def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
352353
if not isinstance(logger, WandBLogger):
@@ -358,4 +359,4 @@ def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, E
358359
f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
359360
for i, param_group in enumerate(self.optimizer.param_groups)
360361
}
361-
logger.log(params, step=global_step, sync=self.sync)
362+
logger.log(params, step=global_step)

tests/ignite/handlers/test_wandb_logger.py

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ def test_optimizer_params():
3131
mock_engine.state.iteration = 123
3232

3333
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
34-
mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123, sync=None)
34+
mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123)
3535

3636
wrapper = OptimizerParamsHandler(optimizer, param_name="lr", tag="generator")
3737
mock_logger = MagicMock(spec=WandBLogger)
3838
mock_logger.log = MagicMock()
3939

4040
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
41-
mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123, sync=None)
41+
mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123)
4242

4343

4444
def test_output_handler_with_wrong_logger_type():
@@ -62,36 +62,14 @@ def test_output_handler_output_transform():
6262

6363
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
6464

65-
mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=None)
65+
mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123)
6666

6767
wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
6868
mock_logger = MagicMock(spec=WandBLogger)
6969
mock_logger.log = MagicMock()
7070

7171
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
72-
mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=None)
73-
74-
75-
def test_output_handler_output_transform_sync():
76-
wrapper = OutputHandler("tag", output_transform=lambda x: x, sync=False)
77-
mock_logger = MagicMock(spec=WandBLogger)
78-
mock_logger.log = MagicMock()
79-
80-
mock_engine = MagicMock()
81-
mock_engine.state = State()
82-
mock_engine.state.output = 12345
83-
mock_engine.state.iteration = 123
84-
85-
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
86-
87-
mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=False)
88-
89-
wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x}, sync=True)
90-
mock_logger = MagicMock(spec=WandBLogger)
91-
mock_logger.log = MagicMock()
92-
93-
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
94-
mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=True)
72+
mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123)
9573

9674

9775
def test_output_handler_metric_names():
@@ -104,7 +82,7 @@ def test_output_handler_metric_names():
10482
mock_engine.state.iteration = 5
10583

10684
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
107-
mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, step=5, sync=None)
85+
mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, step=5)
10886

10987
wrapper = OutputHandler("tag", metric_names=["a", "c"])
11088
mock_engine = MagicMock()
@@ -115,7 +93,7 @@ def test_output_handler_metric_names():
11593
mock_logger.log = MagicMock()
11694

11795
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
118-
mock_logger.log.assert_called_once_with({"tag/a": 55.56, "tag/c": "Some text"}, step=7, sync=None)
96+
mock_logger.log.assert_called_once_with({"tag/a": 55.56, "tag/c": "Some text"}, step=7)
11997

12098
# all metrics
12199
wrapper = OutputHandler("tag", metric_names="all")
@@ -127,7 +105,7 @@ def test_output_handler_metric_names():
127105
mock_engine.state.iteration = 5
128106

129107
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
130-
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45}, step=5, sync=None)
108+
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45}, step=5)
131109

132110
# log a torch vector
133111
wrapper = OutputHandler("tag", metric_names="all")
@@ -139,7 +117,7 @@ def test_output_handler_metric_names():
139117
mock_engine.state.iteration = 5
140118

141119
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
142-
mock_logger.log.assert_called_once_with({f"tag/a/{i}": vector[i].item() for i in range(5)}, step=5, sync=None)
120+
mock_logger.log.assert_called_once_with({f"tag/a/{i}": vector[i].item() for i in range(5)}, step=5)
143121

144122
wrapper = OutputHandler("tag", metric_names=["a"])
145123
mock_engine = MagicMock()
@@ -151,7 +129,7 @@ def test_output_handler_metric_names():
151129
mock_logger.log = MagicMock()
152130

153131
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
154-
mock_logger.log.assert_called_once_with({f"tag/a/{i}": v for i, v in enumerate(data)}, step=7, sync=None)
132+
mock_logger.log.assert_called_once_with({f"tag/a/{i}": v for i, v in enumerate(data)}, step=7)
155133

156134
wrapper = OutputHandler("tag", metric_names="all")
157135
mock_engine = MagicMock()
@@ -179,7 +157,6 @@ def test_output_handler_metric_names():
179157
"tag/c/2/e": 32.1,
180158
},
181159
step=7,
182-
sync=None,
183160
)
184161

185162

@@ -195,7 +172,7 @@ def test_output_handler_both():
195172

196173
wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
197174

198-
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5, sync=None)
175+
mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5)
199176

200177

201178
def test_output_handler_with_wrong_global_step_transform_output():
@@ -229,7 +206,7 @@ def global_step_transform(*args, **kwargs):
229206
mock_engine.state.output = 12345
230207

231208
wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
232-
mock_logger.log.assert_called_once_with({"tag/loss": 12345}, step=10, sync=None)
209+
mock_logger.log.assert_called_once_with({"tag/loss": 12345}, step=10)
233210

234211

235212
def test_output_handler_with_global_step_from_engine():
@@ -254,7 +231,7 @@ def test_output_handler_with_global_step_from_engine():
254231

255232
wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
256233
mock_logger.log.assert_called_once_with(
257-
{"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None
234+
{"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch
258235
)
259236

260237
mock_another_engine.state.epoch = 11
@@ -263,7 +240,7 @@ def test_output_handler_with_global_step_from_engine():
263240
wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
264241
assert mock_logger.log.call_count == 2
265242
mock_logger.log.assert_has_calls(
266-
[call({"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None)]
243+
[call({"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
267244
)
268245

269246

@@ -291,7 +268,6 @@ def test_output_handler_state_attrs():
291268
"tag/delta": "Some Text",
292269
},
293270
step=5,
294-
sync=None,
295271
)
296272

297273

0 commit comments

Comments
 (0)