Conversation

@mathemakitten
Contributor

What does this PR do?

Previously, setting `args.rl_offload_kv_cache_during_training` offloaded the KV cache to CPU to save memory when switching to training mode, and reloaded it onto the device when switching back to inference. The reloaded KV cache tensor was instantiated at a new memory address, which required rebuilding the inference cudagraphs on every step.

Orthogonally, training cudagraphs previously could only be used when the KV cache, inference activations, and training activations could all be co-located within HBM.

Now, when transitioning from inference to training, we use torch_memory_saver to suspend the physical memory underlying the KV cache tensor and offload its contents to CPU. Notably, the virtual memory address of the KV cache tensor is kept consistent. Training cudagraphs can now be captured while accounting for this offload: since the KV cache tensor has been offloaded with torch_memory_saver.pause(), its physical GPU memory is released, so it is safe to reuse that memory for other training-side tensors.

When transitioning from training to inference, physical memory is simply re-bound to the same virtual address. This means that cudagraphs no longer need to be disabled between phases, since the KV cache keeps a consistent virtual address despite being offloaded and onloaded. This is also compatible with partial rollouts.
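
A minimal sketch of this flow using torch_memory_saver (the shape, dtype, and tag name here are illustrative placeholders, not the actual implementation):

    import torch
    from torch_memory_saver import torch_memory_saver

    # Placeholder shape; the real KV cache shape comes from the dynamic inference context.
    kv_cache_shape = (1024, 2, 8, 128)

    # Allocate the KV cache inside a saver region so its virtual address survives
    # pause/resume; enable_cpu_backup keeps a host copy of the contents while paused.
    with torch_memory_saver.region(tag="inference_context", enable_cpu_backup=True):
        kv_cache = torch.empty(kv_cache_shape, dtype=torch.bfloat16, device="cuda")

    # Inference -> training: release the physical GPU memory behind the KV cache
    # (contents are backed up to CPU); the virtual address of `kv_cache` is unchanged.
    torch_memory_saver.pause()

    # ... training step runs here; the freed GPU memory can back training-side tensors ...

    # Training -> inference: re-map physical memory at the same virtual address and
    # restore the contents from the CPU backup, so captured cudagraphs stay valid.
    torch_memory_saver.resume()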

The result is a minor increase in max reserved memory, with lower peak allocated memory and a consistent memory usage pattern over many iterations. This replaces the old behaviour of `args.rl_offload_kv_cache_during_training`.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or tag @mcore-oncall in a comment to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and CI is passing.
Final Review may be declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr or core-nemo will be able to merge your PR.

@tdene
Contributor

tdene commented Jan 22, 2026

For now, can we preserve both the UVM and torch_memory_saver functionality, since it seems like keeping both requires only minimal support?

You'd have to make the torch_memory_saver context not be used when UVM is turned on, by directly changing this block in dynamic_context.py:

        ctx_manager = (
            torch.cuda.use_mem_pool(self.unified_memory_mempool)
            if self.unified_memory_level > 0
            else nullcontext()
        )

to make the `else` portion be `torch_memory_saver.region(tag="inference_context", enable_cpu_backup=True)`.
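
That is, something along these lines (a sketch of the suggested change; the UVM branch stays as-is):

        ctx_manager = (
            torch.cuda.use_mem_pool(self.unified_memory_mempool)
            if self.unified_memory_level > 0
            else torch_memory_saver.region(tag="inference_context", enable_cpu_backup=True)
        )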

And you'd have to add to the if statement in the argument validation:

    if args.rl_offload_kv_cache_during_training and not args.inference_dynamic_batching_unified_memory_level:
        try:
            from torch_memory_saver import torch_memory_saver
        except ImportError:
            raise AssertionError("To use offload-kv-cache-during-training, `torch_memory_saver` must be installed. See https://github.com/fzyzcjy/torch_memory_saver.")

EDIT: Mainly because I think we should run perf benchmarks with both options to check there's no unexpected behavior.

args.data_parallel_size = args.world_size // total_model_size

# Assert that `torch_memory_saver` is installed if offloading KV cache during RL.
if args.rl_offload_kv_cache_during_training:
Contributor

Probably want to add an import check somewhere in core as well, since it's used there in the context.
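
For example, a guarded import at module level, along the lines of the usual optional-dependency flags (the flag name below is illustrative):

    try:
        from torch_memory_saver import torch_memory_saver

        HAVE_TORCH_MEMORY_SAVER = True
    except ImportError:
        HAVE_TORCH_MEMORY_SAVER = False

The context could then assert HAVE_TORCH_MEMORY_SAVER when KV cache offload is requested, instead of silently falling back.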


Contributor

Yeah, but shouldn't we raise an error if someone thinks we are using it? Seems like silently falling back to a null context if it's not available will give users unexpected behavior.

Contributor Author

I've added an `assert not (self.offload_kv_cache and self.unified_memory_level)` in megatron/core/inference/contexts/dynamic_context.py for core. The check for `self.offload_kv_cache` is now sufficient; we need to keep the nullcontext path because we use it in the unified memory case.
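
For reference, a guard of that shape might look like the following (the message text is illustrative, not the exact code in the PR):

    assert not (self.offload_kv_cache and self.unified_memory_level), (
        "Offloading the KV cache via torch_memory_saver cannot be combined with "
        "unified memory; disable one of the two options."
    )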

Contributor

@wdykas wdykas left a comment

small comments but LGTM
