[WIP] Integration with DeepLabCut 3.0 - PyTorch Engine #121
base: main
Conversation
This commit corrects the parameter names and order and removes unsupported arguments from the `benchmark_videos` call:
- device
- precision
- draw_keypoint_names
- get_sys_info
- pyproject.toml is made PEP 517 / 518 / 621–compliant, which makes it easy to inherit dependencies
- it is more tool-agnostic (no poetry required & facilitates uv packaging)
- Python versions 3.10 up to 3.12
- tensorflow installation is restricted to python < 3.12 for windows
Co-authored-by: Copilot <[email protected]>
- Remove upper bound for tensorflow version if python version >=3.11
- add pip as required dependency
remove tkinter as required dependency.
(remove conditional tf dependency install for windows and python >= 3.11, since never reached)
…liant *Update pyproject.toml and CI*
- pyproject.toml is made PEP 517 / 518 / 621–compliant, which makes it easy to inherit dependencies
- it is more tool-agnostic (no poetry required & facilitates uv packaging)
- Python versions 3.10 up to 3.12
- Tensorflow installation is restricted to python < 3.11 for windows
- CI is updated to test more Python versions and to use uv instead of poetry.
Added a more robust downloading function in `utils.py` that handles errors more gracefully. The download is skipped if the file already exists. In the previous implementation, the benchmarking video was downloaded to the home directory, which on Windows could result in permission issues when reading the file. This is now changed to the check_install directory. OpenCV was silently failing to read the video file, resulting in a frame count of zero. Now a ValueError is raised when the video cannot be read.
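For reference, a minimal sketch of the kind of guarded download and video check described above (the function names, destination handling, and error messages here are illustrative assumptions, not the actual `utils.py` code):

```python
from pathlib import Path
from urllib.request import urlretrieve

import cv2


def download_if_missing(url: str, dest: Path) -> Path:
    """Download `url` to `dest`, skipping the download if the file already exists."""
    dest = Path(dest)
    if dest.exists():
        return dest
    dest.parent.mkdir(parents=True, exist_ok=True)
    try:
        urlretrieve(url, dest)
    except Exception as err:
        # Remove partial downloads so a retry starts from a clean state.
        dest.unlink(missing_ok=True)
        raise RuntimeError(f"Failed to download {url}: {err}") from err
    return dest


def count_frames(video_path: Path) -> int:
    """Raise instead of silently returning a frame count of zero."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"OpenCV could not read video file: {video_path}")
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    if n_frames <= 0:
        raise ValueError(f"Video file appears to be empty or corrupt: {video_path}")
    return n_frames
```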
- removed poetry instructions
- added optional instructions for installation with uv
- added notes on tensorflow support for windows
deruyter92 left a comment
Looked through most of the changes and addressed the merge conflicts. I think it would be great if we could merge this WIP as soon as possible, so that we can cleanly contribute additional PRs.
What I have worked on:
- Merged with main (accepting combination of both, but mostly dlclive3 branch for updated code)
- Complete update of pyproject.toml
- more tool-agnostic (no poetry required)
- uv-friendly, so the user can choose their own preferred packaging method
- updated missing dependencies
- Python versions 3.10, 3.11, 3.12 (except for tensorflow users on windows)
- restrict windows users to python 3.10 for tensorflow
- Changed the CI accordingly to test for multiple python versions.
- CI uses uv now for installation.
- Added some basic testing
- Addressed Copilot and other comments
@sneakers-the-rat, @MMathisLab, @C-Achard please let me know what you think and if you'd like me to incorporate additional changes.
.github/workflows/testing.yml
Outdated
- name: Install and test
  shell: bash -el {0}  # Important: activates the conda environment
  run: |
    conda install pytables==3.8.0 "numpy<2"
Thanks. This CI workflow is updated now. Let me know what you think!
pyproject.toml
Outdated
  [tool.poetry.dependencies]
- python = ">=3.9,<3.12"
+ python = ">=3.10,<3.12"
I've been testing with 3.12 recently. It seems to work fine, except on windows with tensorflow. I've updated the pyproject.toml. Let me know if you have suggestions!
pyproject.toml
Outdated
| tables = "^3.6" | ||
| pandas = ">=1.0.1,!=1.5.0" | ||
| tables = "^3.8" | ||
| pytest = "^8.0" |
Agreed!
@deruyter92 Anything I can focus on/help with?
I wonder if we could remove the package-level upper bound on python version and just throw an error if someone tries to use tensorflow with python > 3.12 (or, probably, just throw an error if it can't be imported, in case someone builds their own tensorflow for >3.12). I am assuming torch works with >3.12, and it would be nice to not have that limit for downstream use, future compatibility, etc. Edit: otherwise I still have the same reservations re: maintainability of duplicating the dlc code here, would prefer that code is split out in such a way that the main …
| print(f"Running dog model: {model_path}") | ||
| benchmark_videos( | ||
| model_path=model_path, | ||
| model_type="base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch", |
I believe this should continue to test tensorflow models, as long as the tensorflow models remain in the benchmarking data. Is that right? There is the possibility that the benchmarking data becomes updated and we silently stop testing tensorflow, so I think it would be nice to parameterize this test to explicitly run with pytorch and tensorflow, and fail the test if no tensorflow models are found.
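A possible shape for that parameterization (a sketch only: `benchmark_models` is an assumed fixture mapping engine name to the model paths found in the benchmarking data; it does not exist in the PR as written):

```python
import pytest


@pytest.mark.parametrize("engine", ["tensorflow", "pytorch"])
def test_benchmark_data_has_models_for(engine, benchmark_models):
    models = benchmark_models.get(engine, [])
    # Fail loudly if the benchmarking data stops shipping models for an engine,
    # rather than silently no longer testing that backend.
    assert models, f"No {engine} models found in the benchmarking data"
    # ...the actual benchmark_videos(...) calls for each model would follow here.
```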
if has_cfg and has_pb:
    return cls.TENSORFLOW
elif path.is_file():
    if path.suffix == ".pt":
Oh wait, are there serialized pytorch models now??? That would dramatically simplify things and avoid the possibility of model drift - we could take out all the model code here (keeping the great work making the switchable backends!)
I agree with your reservations regarding the duplication of the DeepLabCut code. Will try to address that.
Regarding the models: no, we are not using serialized models currently. PyTorch models are exported by saving both the state dict and the model configuration in the .pt file. In DeepLabCut-live the model is initialized from this configuration and the weights are then loaded, both using the exported .pt file.
(i.e. the model architecture is serialized, but the class definition must still be available in the environment when loading to reconstruct it).
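For readers unfamiliar with the pattern, a rough sketch of what "state dict plus configuration in one .pt file" implies (the key names and the toy `build_model_from_config` factory are assumptions, not the actual export code):

```python
import torch
from torch import nn


def build_model_from_config(cfg: dict) -> nn.Module:
    # Stand-in for the real model factory: rebuild the architecture from config.
    return nn.Linear(cfg["in_features"], cfg["out_features"])


# Export side: save the configuration and the weights together; no code is pickled.
cfg = {"in_features": 8, "out_features": 2}
model = build_model_from_config(cfg)
torch.save({"config": cfg, "model_state_dict": model.state_dict()}, "model.pt")

# Import side: the class definitions (here, the factory) must still be available
# in the environment to rebuild the model before the weights can be loaded.
checkpoint = torch.load("model.pt", map_location="cpu")
restored = build_model_from_config(checkpoint["config"])
restored.load_state_dict(checkpoint["model_state_dict"])
restored.eval()
```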
Ah, that's a shame. It seems like it should be possible to pickle the model classes etc., but if the torch serialization doesn't do that, then yes, let's not try to invent a file format here. Carry on.
Co-authored-by: Jonny Saunders <[email protected]>
from dlclive.engine import Engine

MODEL_NAME = "superanimal_quadruped"
SNAPSHOT_NAME = "snapshot-700000.pb"
It seems to me that this always assumes that tf is installed - should this be updated dynamically?
from importlib.util import find_spec

has_tf = find_spec("tensorflow") is not None
has_torch = find_spec("torch") is not None
if has_tf:
    model_type = "base"
    MODEL_NAME = "superanimal_quadruped"
    SNAPSHOT_NAME = "snapshot-700000.pb"
elif has_torch:
    model_type = "pytorch"
    MODEL_NAME = ...  # "pt model"
    SNAPSHOT_NAME = ...  # "model.pth"
else:
    raise RuntimeError(
        "Neither TensorFlow nor PyTorch is installed. "
        "Please install one using the `tf` or `pytorch` optional dependencies."
    )
Seems like this exception should be raised in one place in the package - if having neither should be considered an exception, it should get raised predictably.
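One way to centralize it (a sketch, not existing code; the function name is hypothetical and the extras names follow the error message quoted above):

```python
from importlib.util import find_spec


def resolve_default_engine() -> str:
    """Pick the available backend, raising the 'neither installed' error in one place."""
    if find_spec("tensorflow") is not None:
        return "base"
    if find_spec("torch") is not None:
        return "pytorch"
    raise RuntimeError(
        "Neither TensorFlow nor PyTorch is installed. "
        "Please install one using the `tf` or `pytorch` optional dependencies."
    )
```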
One thing I'd like to add before merging is improved modelzoo functionality where PyTorch models are exported in DeepLabCut-live format. This requires finding the corresponding model configurations & storing them inside the .pt file.
Let's leave that for a follow-up PR; this is already too large.
Pull request overview
Copilot reviewed 84 out of 89 changed files in this pull request and generated 17 comments.
self._register_module(module=module, module_name=name, force=force)
return module

return
Copilot AI (Jan 15, 2026)
The register_module decorator function returns None when used as a decorator. It should return the _register function to work properly as a decorator. Change line 330 from 'return' to 'return _register'.
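For context, a generic sketch of the registry-decorator pattern the comment refers to (the real `register_module` in this PR may have a different signature; this only illustrates why returning `_register` matters):

```python
class Registry:
    def __init__(self):
        self._modules = {}

    def _register_module(self, module, module_name=None, force=False):
        name = module_name or module.__name__
        if not force and name in self._modules:
            raise KeyError(f"{name} is already registered")
        self._modules[name] = module

    def register_module(self, name=None, force=False, module=None):
        # Direct call: registry.register_module(module=MyClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # Decorator call: @registry.register_module()
        def _register(cls):
            self._register_module(module=cls, module_name=name, force=force)
            return cls

        # A bare `return` here makes the decorator return None, so the decorated
        # class would be replaced by None at import time.
        return _register
```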
pool1 = torch.nn.MaxPool2d(3, 1, 1)  # TODO JR 01/2026: Are these unused variables informative?
pool2 = torch.nn.MaxPool2d(5, 1, 2)
pool3 = torch.nn.MaxPool2d(7, 1, 3)
Copilot AI (Jan 15, 2026)
Variables pool1 and pool3 are defined but never used. Only pool2 is used on line 270. Either remove the unused variables or clarify their purpose if they're placeholders for future use.
    SK_IM = True
-except Exception:
+except ImportError as e:
Copilot AI (Jan 15, 2026)
The exception variable 'e' is caught but never used. Either remove the variable binding (use 'except ImportError:') or use it for logging/debugging purposes.
    OPEN_CV = True
-except Exception:
+except ImportError as e:
Copilot AI (Jan 15, 2026)
The exception variable 'e' is caught but never used. Either remove the variable binding (use 'except ImportError:') or use it for logging/debugging purposes.
These exceptions are strange to me in a different way - why are we catching import errors for packages listed in the dependencies or trying to import packages that are not listed in the deps?
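If the guards are meant only for genuinely optional packages, one common pattern looks like this (a sketch; `require_opencv` is a hypothetical helper, not code from this PR):

```python
try:
    import cv2  # optional: only needed for the display/benchmark utilities

    OPEN_CV = True
except ImportError:
    OPEN_CV = False


def require_opencv() -> None:
    """Make the failure explicit at the point where the optional feature is used."""
    if not OPEN_CV:
        raise ImportError(
            "opencv-python is required for this feature; install it (or the "
            "corresponding optional dependency group) to enable it."
        )
```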
  Returns
  -------
- :class:`numpy.ndarray`
+ :class: `numpy.ndarray`
Copilot AI (Jan 15, 2026)
The spacing in this reStructuredText class reference should be ':class:`numpy.ndarray`' with no space after the role, not ':class: `numpy.ndarray`'.
Tests for the Engine class - engine detection and model type handling
"""
import pytest
from pathlib import Path
Copilot AI (Jan 15, 2026)
Import of 'Path' is not used.
Tests for the Factory module - runner building
"""
import pytest
from pathlib import Path
Copilot AI (Jan 15, 2026)
Import of 'Path' is not used.
from pathlib import Path
from unittest.mock import Mock, patch
from dlclive import factory
from dlclive.engine import Engine
Copilot AI (Jan 15, 2026)
Import of 'Engine' is not used.
| """ | ||
| Tests for Processor base class | ||
| """ | ||
| import pytest |
Copilot AI (Jan 15, 2026)
Import of 'pytest' is not used.
| """ | ||
| import pytest | ||
| import numpy as np | ||
| from pathlib import Path |
Copilot AI (Jan 15, 2026)
Import of 'Path' is not used.
This pull request updates DeepLabCut-Live for models exported with DeepLabCut 3.0. TensorFlow models can still be used, and the code is siloed so that only the engine used to run the code is required as a package (i.e. no need to install TensorFlow if you want to run live pose estimation with PyTorch models).
If you want to give this PR a try, you can install the code in your local conda environment by running:
pip install "git+https://github.com/DeepLabCut/DeepLabCut-live.git@dlclive3"