Skip to content

Commit

Permalink
generate_patches.py: Add unidiff() to GitRepo
Browse files Browse the repository at this point in the history
This method is copied from utils.git module in ncusi/MSR_Challenge_2024
https://github.com/ncusi/MSR_Challenge_2024/blob/main/src/utils/git.py

This is the start of a series of revisions that will introduce an API for
generating the diff annotation processing pipeline.  The GitRepo class
will be able to generate a list of unidiff.PatchSet objects and pass it to
the annotation process, instead of having to save patches to disk.
  • Loading branch information
jnareb committed Jul 12, 2024
1 parent 4810f9b commit 897ce4b
Showing 1 changed file with 67 additions and 2 deletions.
69 changes: 67 additions & 2 deletions src/diffannotator/generate_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import re
import subprocess
from pathlib import Path
from typing import Optional, Union, TypeVar
from typing import Optional, Union, TypeVar, overload, Literal
from typing import Iterable # should be imported from collections.abc

import typer
from typing_extensions import Annotated
from unidiff import PatchSet

# TODO: move to __init__.py (it is common to all scripts)
PathLike = TypeVar("PathLike", str, bytes, Path, os.PathLike)
Expand All @@ -30,6 +31,11 @@ class GitRepo:
path_encoding = 'utf8'
default_file_encoding = 'utf8'
log_encoding = 'utf8'
fallback_encoding = 'latin1' # must be 8-bit encoding
# see 346245a1bb ("hard-code the empty tree object", 2008-02-13)
# https://github.com/git/git/commit/346245a1bb6272dd370ba2f7b9bf86d3df5fed9a
# https://github.com/git/git/commit/e1ccd7e2b1cae8d7dab4686cddbd923fb6c46953
empty_tree_sha1 = '4b825dc642cb6eb9a060e54bf8d69288fbee4904'

def __init__(self, repo_dir: PathLike):
"""Constructor for `GitRepo` class
Expand Down Expand Up @@ -182,6 +188,65 @@ def format_patch(self,
else:
return process.stderr

@overload
def unidiff(self, commit: str = ..., prev: Optional[str] = ..., wrap: Literal[True] = ...) -> PatchSet:
    ...

@overload
def unidiff(self, commit: str = ..., prev: Optional[str] = ..., *, wrap: Literal[False]) -> Union[str, bytes]:
    ...

@overload
def unidiff(self, commit: str = ..., prev: Optional[str] = ..., wrap: bool = ...) -> Union[str, bytes, PatchSet]:
    ...

def unidiff(self, commit='HEAD', prev=None, wrap=True):
    """Return unified diff between `commit` and `prev`

    If `prev` is None (which is the default), return diff between the
    `commit` and its first parent, or between the `commit` and the empty
    tree if `commit` does not have any parents (if it is a root commit).

    If `wrap` is True (which is the default), wrap the result in
    unidiff.PatchSet to make it easier to extract information from
    the diff.  Otherwise, return diff as plain text.

    The output of `git diff` is decoded with `self.default_file_encoding`;
    if that fails, it is decoded with `self.fallback_encoding` instead,
    because unidiff.PatchSet can only handle strings.

    :param str commit: later (second) of two commits to compare,
        defaults to 'HEAD', that is the current commit
    :param prev: earlier (first) of two commits to compare,
        defaults to None, which means comparing to parent of `commit`
    :type prev: str or None
    :param bool wrap: whether to wrap the result in PatchSet
    :return: the changes between two arbitrary commits,
        `prev` and `commit`
    :rtype: str or bytes or PatchSet
    :raises subprocess.CalledProcessError: if `git diff` exits with
        non-zero status (for example when `commit` cannot be resolved)
    """
    if prev is None:
        try:
            # NOTE: this means first-parent changes for merge commits
            return self.unidiff(commit=commit, prev=commit + '^', wrap=wrap)
        except subprocess.CalledProcessError:
            # commit^ does not exist for a root commit (for the first commit);
            # if the failure had another cause (e.g. `commit` is unknown),
            # the call below fails as well and its exception propagates
            return self.unidiff(commit=commit, prev=self.empty_tree_sha1, wrap=wrap)

    cmd = [
        'git', '-C', self.repo,
        'diff', '--find-renames', '--find-copies', '--find-copies-harder',
        prev, commit,
        # explicit end of revisions: without it git refuses to run
        # ("ambiguous argument") when a revision name is also the name
        # of a file or directory in the working tree
        '--',
    ]
    process = subprocess.run(cmd,
                             capture_output=True, check=True)
    try:
        diff_output = process.stdout.decode(self.default_file_encoding)
    except UnicodeDecodeError:
        # unidiff.PatchSet can only handle strings
        diff_output = process.stdout.decode(self.fallback_encoding)

    if wrap:
        return PatchSet(diff_output)
    else:
        return diff_output


app = typer.Typer(no_args_is_help=True, add_completion=False)

Expand All @@ -190,7 +255,7 @@ def format_patch(self,
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
def main(
ctx: typer.Context,
ctx: typer.Context,
repo_path: Annotated[
Path,
typer.Argument(
Expand Down

0 comments on commit 897ce4b

Please sign in to comment.