src/github_actions_tracing/main.py (67 additions, 0 deletions)
@@ -8,6 +8,9 @@
from requests.structures import CaseInsensitiveDict
from watcloud_utils.typer import app

import base64
import yaml

from vendor.generated import perfetto_trace_pb2

from ._version import __version__
@@ -42,6 +45,29 @@ def run_graphql_query(query, token):
return response.json()


def get_dependencies(data):
"""
Retrieve workflow job dependencies from UTF-8 decoded file content.
"""

# Parse workflow file as YAML document
try:
workflow = yaml.safe_load(data)
except yaml.YAMLError as e:
raise ValueError(f"Failed to parse workflow YAML: {e}") from e

dependencies = {}
jobs = workflow.get("jobs", {})

for job, config in jobs.items():
name = config.get("name", job)
needs = config.get("needs", [])
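# GitHub Actions allows `needs` to be a single job id or a list of job ids,
# so normalize it to a list before resolving display names.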
needs = [needs] if not isinstance(needs, list) else needs
dependencies[name] = [jobs[n].get("name", n) for n in needs]

return dependencies
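
# Illustrative sketch (hypothetical two-job workflow): get_dependencies maps
# each job's display name to the display names of the jobs it needs. For
#
#     jobs:
#       build:
#         name: Build
#       test:
#         name: Test
#         needs: build
#
# it returns {"Build": [], "Test": ["Build"]}.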


@app.command()
def get_data(github_url, github_token=None):
"""
@@ -85,10 +111,27 @@ def get_data(github_url, github_token=None):
jobs_data = jobs_response.json().get("jobs", [])
if not jobs_data:
raise ValueError("Failed to retrieve jobs data.")

# Extract the commit SHA and workflow file path from the workflow run details
ref = run_data["head_sha"]
path = run_data["path"].split("@")[0] if "@" in run_data["path"] else run_data["path"]

# Retrieve the Base64-encoded workflow file content and decode it
content_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={ref}"
content_response = requests.get(content_url, headers=headers)

try:
content_response.raise_for_status()
encoded_content_data = content_response.json().get("content")
decoded_content_data = base64.b64decode(encoded_content_data).decode("utf-8")
dependencies_data = get_dependencies(decoded_content_data)
except Exception:
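# Dependency info is best-effort: if the workflow file cannot be fetched or
# parsed, fall back to an empty mapping rather than failing the whole command.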
dependencies_data = {}

return {
"run": run_data,
"jobs": jobs_data,
"dependencies": dependencies_data,
}
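
# Illustrative sketch of the contents-API exchange above (path and values are
# hypothetical):
#
#     GET /repos/{owner}/{repo}/contents/.github/workflows/ci.yaml?ref=<head_sha>
#     -> {"name": "ci.yaml", "encoding": "base64", "content": "am9iczoK...", ...}
#
# The "content" field is Base64 with embedded newlines; base64.b64decode
# discards them by default (validate=False), so it can be decoded as-is.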


@@ -144,6 +187,19 @@ def create_trace_file(

jobs = sorted(data["jobs"], key=lambda job: job["created_at"])

# ID lookup via job name
ids = {job["name"]: job["id"] for job in jobs}

# IDs of flows originating at dependency, ending at dependent job
flow_uuids = {}
for dependent, dependencies in data["dependencies"].items():
if dependent not in ids:
continue
for dependency in dependencies:
if dependency not in ids:
continue
flow_uuids[(dependency, dependent)] = generate_uuid(ids[dependency], ids[dependent])
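
# Illustrative sketch (hypothetical names and ids): for a workflow in which
# "Test" needs "Build",
#
#     ids        == {"Build": 101, "Test": 102}
#     flow_uuids == {("Build", "Test"): generate_uuid(101, 102)}
#
# The resulting flow id is attached to "Build"'s slice-end event (flow_ids)
# and to "Test"'s slice-begin event (terminating_flow_ids) below, so the
# Perfetto UI draws an arrow from the end of the dependency to the start of
# the dependent job.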

for job_i, job in enumerate(jobs):
if job["conclusion"] in ["skipped"]:
continue
@@ -184,6 +240,13 @@
job_create_packet.track_event.name = job["name"]
job_create_packet.track_event.categories.extend(["job", "slice"])

# List of flow ids which should terminate on this event
# https://perfetto.dev/docs/reference/trace-packet-proto
for dependency in data["dependencies"].get(job["name"], []):
flow_uuid = flow_uuids.get((dependency, job["name"]))
if flow_uuid:
job_create_packet.track_event.terminating_flow_ids.append(flow_uuid)

job_waiting_start_packet = trace.packet.add()
job_waiting_start_packet.trusted_packet_sequence_id = TRUSTED_PACKET_SEQUENCE_ID
job_waiting_start_packet.timestamp = job_create_ns
@@ -250,6 +313,10 @@ def create_trace_file(
perfetto_trace_pb2.TrackEvent.Type.TYPE_SLICE_END
)

for (dependency, _), flow_uuid in flow_uuids.items():
if dependency == job["name"]:
job_end_packet.track_event.flow_ids.append(flow_uuid)

# Write the trace to a JSON file for debugging
if output_debug_json:
with open(output_debug_json, "w") as f: