Skip to content

Push negotiation callback #1396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion pygit2/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from pygit2._libgit2.ffi import GitProxyOptionsC

from ._pygit2 import CloneOptions, PushOptions
from .remotes import TransferProgress
from .remotes import PushUpdate, TransferProgress
#
# The payload is the way to pass information from the pygit2 API, through
# libgit2, to the Python callbacks. And back.
Expand Down Expand Up @@ -198,6 +198,15 @@ def certificate_check(self, certificate: None, valid: bool, host: bytes) -> bool

raise Passthrough

def push_negotiation(self, updates: list['PushUpdate']) -> None:
"""
During a push, called once between the negotiation step and the upload.
Provides information about what updates will be performed.

Override with your own function to check the pending updates
and possibly reject them (by raising an exception).
"""

def transfer_progress(self, stats: 'TransferProgress') -> None:
"""
During the download of new data, this will be regularly called with
Expand Down Expand Up @@ -427,6 +436,7 @@ def git_push_options(payload, opts=None):
opts.callbacks.credentials = C._credentials_cb
opts.callbacks.certificate_check = C._certificate_check_cb
opts.callbacks.push_update_reference = C._push_update_reference_cb
opts.callbacks.push_negotiation = C._push_negotiation_cb
# Per libgit2 sources, push_transfer_progress may incur a performance hit.
# So, set it only if the user has overridden the no-op stub.
if (
Expand Down Expand Up @@ -559,6 +569,19 @@ def _credentials_cb(cred_out, url, username, allowed, data):
return 0


@libgit2_callback
def _push_negotiation_cb(updates, num_updates, data):
from .remotes import PushUpdate

push_negotiation = getattr(data, 'push_negotiation', None)
if not push_negotiation:
return 0

py_updates = [PushUpdate(updates[i]) for i in range(num_updates)]
push_negotiation(py_updates)
return 0


@libgit2_callback
def _push_update_reference_cb(ref, msg, data):
push_update_reference = getattr(data, 'push_update_reference', None)
Expand Down
5 changes: 5 additions & 0 deletions pygit2/decl/callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ extern "Python" int _push_update_reference_cb(
const char *status,
void *data);

extern "Python" int _push_negotiation_cb(
const git_push_update **updates,
size_t len,
void *data);

extern "Python" int _remote_create_cb(
git_remote **out,
git_repository *repo,
Expand Down
40 changes: 38 additions & 2 deletions pygit2/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ class LsRemotesDict(TypedDict):
oid: Oid


class PushUpdate:
"""
Represents an update which will be performed on the remote during push.
"""

src_refname: str
"""The source name of the reference"""

dst_refname: str
"""The name of the reference to update on the server"""

src: Oid
"""The current target of the reference"""

dst: Oid
"""The new target for the reference"""

def __init__(self, c_struct: Any) -> None:
src_refname = maybe_string(c_struct.src_refname)
dst_refname = maybe_string(c_struct.dst_refname)
assert src_refname is not None, 'libgit2 returned null src_refname'
assert dst_refname is not None, 'libgit2 returned null dst_refname'
self.src_refname = src_refname
self.dst_refname = dst_refname
self.src = Oid(raw=bytes(ffi.buffer(c_struct.src.id)[:]))
self.dst = Oid(raw=bytes(ffi.buffer(c_struct.dst.id)[:]))


class TransferProgress:
"""Progress downloading and indexing data during a fetch."""

Expand Down Expand Up @@ -196,7 +224,10 @@ def fetch(
return TransferProgress(C.git_remote_stats(self._remote))

def ls_remotes(
self, callbacks: RemoteCallbacks | None = None, proxy: str | None | bool = None
self,
callbacks: RemoteCallbacks | None = None,
proxy: str | None | bool = None,
connect: bool = True,
) -> list[LsRemotesDict]:
"""
Return a list of dicts that maps to `git_remote_head` from a
Expand All @@ -207,9 +238,14 @@ def ls_remotes(
callbacks : Passed to connect()

proxy : Passed to connect()

connect : Whether to connect to the remote first. You can pass False
if the remote has already connected. The list remains available after
disconnecting as long as a new connection is not initiated.
"""

self.connect(callbacks=callbacks, proxy=proxy)
if connect:
self.connect(callbacks=callbacks, proxy=proxy)

refs = ffi.new('git_remote_head ***')
refs_len = ffi.new('size_t *')
Expand Down
64 changes: 61 additions & 3 deletions test/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

import pygit2
from pygit2 import Remote, Repository
from pygit2.remotes import TransferProgress
from pygit2.remotes import PushUpdate, TransferProgress

from . import utils

Expand Down Expand Up @@ -204,6 +204,22 @@ def test_ls_remotes(testrepo: Repository) -> None:
assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2'))


@utils.requires_network
def test_ls_remotes_without_implicit_connect(testrepo: Repository) -> None:
assert 1 == len(testrepo.remotes)
remote = testrepo.remotes[0]

with pytest.raises(pygit2.GitError, match='this remote has never connected'):
remote.ls_remotes(connect=False)

remote.connect()
refs = remote.ls_remotes(connect=False)
assert refs

# Check that a known ref is returned.
assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2'))


def test_remote_collection(testrepo: Repository) -> None:
remote = testrepo.remotes['origin']
assert REMOTE_NAME == remote.name
Expand Down Expand Up @@ -406,9 +422,12 @@ def push_transfer_progress(
assert origin.branches['master'].target == new_tip_id


@pytest.mark.parametrize('reject_from', ['push_transfer_progress', 'push_negotiation'])
def test_push_interrupted_from_callbacks(
origin: Repository, clone: Repository, remote: Remote
origin: Repository, clone: Repository, remote: Remote, reject_from: str
) -> None:
reject_message = 'retreat! retreat!'

tip = clone[clone.head.target]
clone.create_commit(
'refs/heads/master',
Expand All @@ -420,10 +439,15 @@ def test_push_interrupted_from_callbacks(
)

class MyCallbacks(pygit2.RemoteCallbacks):
def push_negotiation(self, updates: list[PushUpdate]) -> None:
if reject_from == 'push_negotiation':
raise InterruptedError(reject_message)

def push_transfer_progress(
self, objects_pushed: int, total_objects: int, bytes_pushed: int
) -> None:
raise InterruptedError('retreat! retreat!')
if reject_from == 'push_transfer_progress':
raise InterruptedError(reject_message)

assert origin.branches['master'].target == tip.id

Expand Down Expand Up @@ -504,3 +528,37 @@ def test_push_threads(origin: Repository, clone: Repository, remote: Remote) ->
callbacks = RemoteCallbacks()
remote.push(['refs/heads/master'], callbacks, threads=1)
assert callbacks.push_options.pb_parallelism == 1


def test_push_negotiation(
origin: Repository, clone: Repository, remote: Remote
) -> None:
old_tip = clone[clone.head.target]
new_tip_id = clone.create_commit(
'refs/heads/master',
old_tip.author,
old_tip.author,
'empty commit',
old_tip.tree.id,
[old_tip.id],
)

the_updates: list[PushUpdate] = []

class MyCallbacks(pygit2.RemoteCallbacks):
def push_negotiation(self, updates: list[PushUpdate]) -> None:
the_updates.extend(updates)

assert origin.branches['master'].target == old_tip.id
assert 'new_branch' not in origin.branches

callbacks = MyCallbacks()
remote.push(['refs/heads/master'], callbacks=callbacks)

assert len(the_updates) == 1
assert the_updates[0].src_refname == 'refs/heads/master'
assert the_updates[0].dst_refname == 'refs/heads/master'
assert the_updates[0].src == old_tip.id
assert the_updates[0].dst == new_tip_id

assert origin.branches['master'].target == new_tip_id
Loading