diff --git a/frg/cli.py b/frg/cli.py index 32a229b..f982d28 100644 --- a/frg/cli.py +++ b/frg/cli.py @@ -4,6 +4,7 @@ import click import pydantic import frg.forgejo.browser as forgejo_browser +import frg.git as git from frg.configuration import Config, get_configuration from frg.context import GitContext, get_git_context @@ -61,6 +62,9 @@ def pr(ctx): @click.pass_obj def create_pr(ctx, web: bool): """Interacts with pull requests.""" + + git.push(branch=ctx.git.current_branch) + if web: forgejo_browser.create_pull_request_via_web( head=ctx.git.current_branch, diff --git a/frg/context.py b/frg/context.py index 530463f..c5a2466 100644 --- a/frg/context.py +++ b/frg/context.py @@ -6,10 +6,11 @@ invocation environment (i.e. git context, ...) to make things work. """ import logging -import subprocess import pydantic +import frg.git as git + logger = logging.getLogger(__name__) @@ -56,15 +57,8 @@ def get_git_context(*, domain_aliases: dict[str, str] | None = None) -> GitConte if not domain_aliases: domain_aliases = dict() - current_branch_cmd = subprocess.run( - ["git", "branch", "--show-current"], capture_output=True - ) - current_branch = current_branch_cmd.stdout.decode("utf8").strip() - - remote_url_cmd = subprocess.run( - ["git", "config", "--get", "remote.origin.url"], capture_output=True - ) - remote_url = remote_url_cmd.stdout.decode("utf8").strip() + current_branch = git.get_current_branch().stdout + remote_url = git.get_current_remote_url().stdout host, owner, repo = parse_remote_url(remote_url) diff --git a/frg/git.py b/frg/git.py new file mode 100644 index 0000000..50ccffd --- /dev/null +++ b/frg/git.py @@ -0,0 +1,34 @@ +import subprocess + +import pydantic + + +class CommandResult(pydantic.BaseModel): + return_code: int + stdout: str + stderr: str + + +def _git(args: list[str]) -> CommandResult: + result = subprocess.run(["git", *args], capture_output=True) + + return CommandResult( + stdout=result.stdout.decode("utf8").strip(), + stderr=result.stderr.decode("utf8").strip(), + return_code=result.returncode, + ) + + +def get_current_branch() -> CommandResult: + """Returns the current checked out branch.""" + return _git(["branch", "--show-current"]) + + +def get_current_remote_url() -> CommandResult: + """Returns the remote origin url.""" + return _git(["config", "--get", "remote.origin.url"]) + + +def push(*, branch: str) -> CommandResult: + """Pushes the current local commits to remote.""" + return _git(["push", "--set-upstream", "origin", branch]) diff --git a/tests/test_context.py b/tests/test_context.py index c94afd3..58c0b33 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -28,6 +28,9 @@ def mock_subprocess(mock_context): "utf8" ) + return_val.stderr = b"" + return_val.returncode = 0 + return return_val with unittest.mock.patch("subprocess.run") as mock_run: