|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Update the third_party/xla/workspace.bzl file to use the given XLA commit""" |
| 3 | + |
| 4 | +import argparse |
| 5 | +import hashlib |
| 6 | +import logging |
| 7 | +import os.path |
| 8 | +import re |
| 9 | +import subprocess |
| 10 | + |
| 11 | +import requests |
| 12 | + |
| 13 | +GH_COMMIT_URL = "https://api.github.com/repos/{0}/commits/{1}" |
| 14 | +GH_BASE_URL = "https://github.com" |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +def update_xla_hash(xla_commit, xla_repo, workspace_file_path, gh_token): |
| 19 | + # Verify that the workspace_file exists |
| 20 | + if not os.path.isfile(workspace_file_path): |
| 21 | + raise ValueError(f"Workspace file '{workspace_file}' does not exist") |
| 22 | + |
| 23 | + # If we were given a GH auth token, use it to make sure that the commit |
| 24 | + # exists and convert a branch name to a commit hash |
| 25 | + if gh_token: |
| 26 | + logger.debug(GH_COMMIT_URL.format(xla_repo, xla_commit)) |
| 27 | + commit_info_resp = requests.get( |
| 28 | + url=GH_COMMIT_URL.format(xla_repo, xla_commit), |
| 29 | + headers={ |
| 30 | + "Accept": "application/vnd.github.sha", |
| 31 | + "Authorization": f"Bearer {gh_token}", |
| 32 | + "X-Github-Api-Version": "2022-11-28", |
| 33 | + }, |
| 34 | + ) |
| 35 | + commit_info_resp.raise_for_status() |
| 36 | + logger.info("Found commit hash via GH API: %s", commit_info_resp.text) |
| 37 | + xla_commit_hash = commit_info_resp.text.strip() |
| 38 | + # If the user didn't give us a token make sure the commit hash looks hashy |
| 39 | + else: |
| 40 | + if not xla_commit.isalnum(): |
| 41 | + raise ValueError( |
| 42 | + f"XLA commit hash '{xla_commit}' is not a valid commit hash" |
| 43 | + ) |
| 44 | + xla_commit_hash = xla_commit |
| 45 | + |
| 46 | + # Get the sha256 of this commit |
| 47 | + xla_zip_resp = requests.get( |
| 48 | + f"{GH_BASE_URL}/{xla_repo}/archive/{xla_commit_hash}.tar.gz" |
| 49 | + ) |
| 50 | + xla_zip_resp.raise_for_status() |
| 51 | + hasher = hashlib.sha256() |
| 52 | + hasher.update(xla_zip_resp.content) |
| 53 | + sha256_hex = hasher.hexdigest().strip() |
| 54 | + logger.info("sha256: %s", sha256_hex) |
| 55 | + |
| 56 | + # Open the workspace file |
| 57 | + with open(workspace_file_path, "r+") as workspace_file: |
| 58 | + contents = workspace_file.read() |
| 59 | + # Edit the commit hash, sha256 hash, and repo |
| 60 | + contents = re.sub( |
| 61 | + 'XLA_COMMIT = "[a-z0-9]*"', |
| 62 | + f'XLA_COMIT = "{xla_commit_hash}"', |
| 63 | + contents, |
| 64 | + flags=re.M, |
| 65 | + ) |
| 66 | + contents = re.sub( |
| 67 | + 'XLA_SHA256 = "[a-z0-9]*"', |
| 68 | + f'XLA_SHA256 = "{sha256_hex}"', |
| 69 | + contents, |
| 70 | + flags=re.M, |
| 71 | + ) |
| 72 | + contents = re.sub( |
| 73 | + 'tf_mirror_urls\("[a-zA-Z0-9:/.]+/archive', |
| 74 | + f'tf_mirror_urls("{GH_BASE_URL}/{xla_repo}/archive', |
| 75 | + contents, |
| 76 | + flags=re.M, |
| 77 | + ) |
| 78 | + # Write to the workspace file |
| 79 | + workspace_file.seek(0) |
| 80 | + workspace_file.write(contents) |
| 81 | + workspace_file.truncate() |
| 82 | + |
| 83 | + |
| 84 | +def parse_args(): |
| 85 | + arg_parser = argparse.ArgumentParser( |
| 86 | + description="Update the XLA commit hash in the workspace.bzl file" |
| 87 | + ) |
| 88 | + arg_parser.add_argument( |
| 89 | + "xla_commit", |
| 90 | + help="Branch or commit to put in the workspace file", |
| 91 | + ) |
| 92 | + arg_parser.add_argument( |
| 93 | + "--gh-token", |
| 94 | + help="Github token to authenticate with. Either the GIHUB_TOKEN from Actions or your PAT.", |
| 95 | + ) |
| 96 | + arg_parser.add_argument( |
| 97 | + "--xla-repo", |
| 98 | + default="openxla/xla", |
| 99 | + help="The repo where this branch or commit can be found. Should be in the form of <owner>/<repo>. Defaults to openxla/xla.", |
| 100 | + ) |
| 101 | + arg_parser.add_argument( |
| 102 | + "--workspace-file", |
| 103 | + default=".jax_rocm_plugin/third_party/xla/workspace.bzl", |
| 104 | + help="Path to the workspace.bzl file to put the hash. Defaults to ./third_party/xla/workspace.bzl.", |
| 105 | + ) |
| 106 | + arg_parser.add_argument( |
| 107 | + "-v", |
| 108 | + "--verbose", |
| 109 | + help="Turn on debug logging", |
| 110 | + action="store_const", |
| 111 | + dest="loglevel", |
| 112 | + const=logging.DEBUG, |
| 113 | + ) |
| 114 | + return arg_parser.parse_args() |
| 115 | + |
| 116 | + |
| 117 | +if __name__ == "__main__": |
| 118 | + args = parse_args() |
| 119 | + logging.basicConfig(level=args.loglevel) |
| 120 | + update_xla_hash(args.xla_commit, args.xla_repo, args.workspace_file, args.gh_token) |
0 commit comments