Skip to content
Closed
201 changes: 68 additions & 133 deletions .github/scripts/rigging_pr_decorator.py
Original file line number Diff line number Diff line change
@@ -1,142 +1,77 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "rigging",
# "typer",
# ]
# ///

import asyncio
import base64
import os
import subprocess
import typing as t

from pydantic import ConfigDict, StringConstraints
import typer

import rigging as rg
from rigging import logger
from rigging.generator import GenerateParams, Generator, register_generator

logger.enable("rigging")

MAX_TOKENS = 8000
TRUNCATION_WARNING = "\n\n**Note**: Due to the large size of this diff, some content has been truncated."
str_strip = t.Annotated[str, StringConstraints(strip_whitespace=True)]


class PRDiffData(rg.Model):
"""XML model for PR diff data"""

content: str_strip = rg.element()

@classmethod
def xml_example(cls) -> str:
return """<diff><content>example diff content</content></diff>"""


class PRDecorator(Generator):
"""Generator for creating PR descriptions"""

model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

api_key: str = ""
max_tokens: int = MAX_TOKENS

def __init__(self, model: str, params: rg.GenerateParams) -> None:
api_key = params.extra.get("api_key")
if not api_key:
raise ValueError("api_key is required in params.extra")

super().__init__(model=model, params=params, api_key=api_key)
self.api_key = api_key
self.max_tokens = params.max_tokens or MAX_TOKENS

async def generate_messages(
self,
messages: t.Sequence[t.Sequence[rg.Message]],
params: t.Sequence[GenerateParams],
) -> t.Sequence[rg.GeneratedMessage]:
responses = []
for message_seq, p in zip(messages, params):
base_generator = rg.get_generator(self.model, params=p)
llm_response = await base_generator.generate_messages([message_seq], [p])
responses.extend(llm_response)
return responses


register_generator("pr_decorator", PRDecorator)


async def generate_pr_description(diff_text: str) -> str:
"""Generate a PR description from the diff text"""
diff_tokens = len(diff_text) // 4
if diff_tokens >= MAX_TOKENS:
char_limit = (MAX_TOKENS * 4) - len(TRUNCATION_WARNING)
diff_text = diff_text[:char_limit] + TRUNCATION_WARNING

diff_data = PRDiffData(content=diff_text)
params = rg.GenerateParams(
extra={
"api_key": os.environ["OPENAI_API_KEY"],
"diff_text": diff_text,
},
temperature=0.1,
max_tokens=500,
)

generator = rg.get_generator("pr_decorator!gpt-4-turbo-preview", params=params)
prompt = f"""You are a helpful AI that generates clear and concise PR descriptions with some pirate tongue.
Analyze the provided git diff and create a summary, specifically focusing on the elements of the code that
has changed, high severity functions etc using exactly this format:

### PR Summary

#### Overview of Changes
<overview paragraph>

#### Key Modifications
1. **<modification title>**: <description>
(continue as needed)

#### Potential Impact
- <impact point 1>
(continue as needed)

Here is the PR diff to analyze:
{diff_data.to_xml()}"""

chat = await generator.chat(prompt).run()
return chat.last.content.strip()


async def main():
"""Main function for CI environment"""
if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY environment variable must be set")

try:
diff_text = os.environ.get("GIT_DIFF", "")
if not diff_text:
raise ValueError("No diff found in GIT_DIFF environment variable")

try:
diff_text = base64.b64decode(diff_text).decode("utf-8")
except Exception:
padding = 4 - (len(diff_text) % 4)
if padding != 4:
diff_text += "=" * padding
diff_text = base64.b64decode(diff_text).decode("utf-8")

logger.debug(f"Processing diff of length: {len(diff_text)}")
description = await generate_pr_description(diff_text)

with open(os.environ["GITHUB_OUTPUT"], "a") as f:
f.write("content<<EOF\n")
f.write(description)
f.write("\nEOF\n")
f.write(f"debug_diff_length={len(diff_text)}\n")
f.write(f"debug_description_length={len(description)}\n")
debug_preview = description[:500]
f.write("debug_preview<<EOF\n")
f.write(debug_preview)
f.write("\nEOF\n")

except Exception as e:
logger.error(f"Error in main: {e}")
raise
TRUNCATION_WARNING = "\n---\n**Note**: Due to the large size of this diff, some content has been truncated."


@rg.prompt
def generate_pr_description(diff: str) -> t.Annotated[str, rg.Ctx("markdown")]: # type: ignore[empty-body]
"""
Analyze the provided git diff and create a PR description in markdown format.

<guidance>
- Keep the summary concise and informative.
- Use bullet points to structure important statements.
- Focus on key modifications and potential impact - if any.
- Do not add in general advice or best-practice information.
- Write like a developer who authored the changes.
- Prefer flat bullet lists over nested.
- Do not include any title structure.
</guidance>
"""


def get_diff(target_ref: str, source_ref: str) -> str:
"""
Get the git diff between two branches.
"""

merge_base = subprocess.run(
["git", "merge-base", source_ref, target_ref],
capture_output=True,
text=True,
check=True,
).stdout.strip()
diff_text = subprocess.run(
["git", "diff", merge_base],
capture_output=True,
text=True,
check=True,
).stdout
return diff_text


def main(
target_ref: str,
source_ref: str = "HEAD",
generator_id: str = "openai/gpt-4o-mini",
max_diff_lines: int = 1000,
) -> None:
"""
Use rigging to generate a PR description from a git diff.
"""

diff = get_diff(target_ref, source_ref)
diff_lines = diff.split("\n")
if len(diff_lines) > max_diff_lines:
diff = "\n".join(diff_lines[:max_diff_lines]) + TRUNCATION_WARNING

description = asyncio.run(generate_pr_description.bind(generator_id)(diff))
print(description)


if __name__ == "__main__":
asyncio.run(main())
typer.run(main)
46 changes: 19 additions & 27 deletions .github/workflows/rigging_pr_description.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Update PR Description with Rigging

on:
pull_request:
types: [opened]
types: [opened, synchronize]

jobs:
update-description:
Expand All @@ -12,48 +12,40 @@ jobs:
contents: read

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
fetch-depth: 0 # full history for proper diffing

# Get the diff first
- name: Get Diff
id: diff
# shellcheck disable=SC2102
run: |
git fetch origin "${{ github.base_ref }}"
MERGE_BASE=$(git merge-base HEAD "origin/${{ github.base_ref }}")
# Use separate diff arguments instead of range notation
DIFF=$(git diff "$MERGE_BASE" HEAD | base64 --wrap=0)
echo "diff=${DIFF}" >> "$GITHUB_OUTPUT"
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b #v5.0.3
- name: Set up Python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.0.3
with:
python-version: "3.11"
python-version: "3.10"

- name: Install dependencies
- name: Install uv
run: |
python -m pip install --upgrade pip
pip cache purge
pip install pydantic
pip install rigging[all]
# Generate the description using the diff
pip install uv

- name: Generate PR Description
id: description
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
PR_NUMBER: ${{ github.event.pull_request.number }}
GIT_DIFF: ${{ steps.diff.outputs.diff }}
run: |
python .github/scripts/rigging_pr_decorator.py
# Update the PR description
DESCRIPTION=$(uv run --no-project .github/scripts/rigging_pr_decorator.py origin/${{ github.base_ref }})
echo "description<<EOF" >> $GITHUB_OUTPUT
echo "$DESCRIPTION" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT

- name: Update PR Description
uses: nefrob/pr-description@4dcc9f3ad5ec06b2a197c5f8f93db5e69d2fdca7 #v1.2.0
uses: nefrob/pr-description@4dcc9f3ad5ec06b2a197c5f8f93db5e69d2fdca7 # v1.2.0
with:
content: |
## AI-Generated Summary
${{ steps.description.outputs.content }}

${{ steps.description.outputs.description }}

---

This summary was generated with ❤️ by [rigging](https://rigging.dreadnode.io/)
regex: ".*"
regexFlags: s
Expand Down
13 changes: 11 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,24 @@ repos:
hooks:
- id: actionlint
name: Check Github Actions
args: ["--ignore", "SC2102"]
args: ["--ignore", "SC2102,SC2129,SC2086"]
exclude: ^rigging_pr_description\.yml$

# Python code security
- repo: https://github.com/PyCQA/bandit
rev: 8fd258abbac759d62863779f946d6a88e8eabb0f #1.8.0
hooks:
- id: bandit
name: Code security checks
args: ["-r", "--level", "2", "./"]
args:
[
-r,
.,
--severity-level,
high,
-x,
.github/scripts/rigging_pr_decorator.py,
]

- repo: local
hooks:
Expand Down
Loading