33use std:: fmt;
44use std:: path:: PathBuf ;
55
6- use anyhow:: { bail, format_err, Context as _ , Error , Result } ;
6+ use anyhow:: { bail, format_err, Error , Result } ;
77use chrono:: { DateTime , FixedOffset , TimeZone as _} ;
88use git2:: { Commit , Repository , Time } ;
99use if_chain:: if_chain;
@@ -272,30 +272,28 @@ pub fn git_repo_head_ref(repo: &git2::Repository) -> Result<String> {
272272}
273273
274274pub fn git_repo_base_ref ( repo : & git2:: Repository , remote_name : & str ) -> Result < String > {
275- // Get the current HEAD commit
276- let head_commit = repo. head ( ) ?. peel_to_commit ( ) ?;
277-
278- // Try to find the remote tracking branch
279275 let remote_branch_name = format ! ( "refs/remotes/{remote_name}/HEAD" ) ;
280276 let remote_ref = repo. find_reference ( & remote_branch_name) . map_err ( |e| {
281277 anyhow:: anyhow!( "Could not find remote tracking branch for {remote_name}: {e}" )
282278 } ) ?;
283279
284- find_merge_base_ref ( repo, & head_commit, & remote_ref)
285- }
286-
287- fn find_merge_base_ref (
288- repo : & git2:: Repository ,
289- head_commit : & git2:: Commit ,
290- remote_ref : & git2:: Reference ,
291- ) -> Result < String > {
292- let remote_commit = remote_ref. peel_to_commit ( ) ?;
293- let merge_base_oid = repo. merge_base ( head_commit. id ( ) , remote_commit. id ( ) ) ?;
280+ let name = remote_ref
281+ . resolve ( ) ?
282+ . shorthand ( )
283+ . ok_or ( anyhow:: anyhow!(
284+ "Could not find remote tracking branch for {remote_name}"
285+ ) ) ?
286+ . to_owned ( ) ;
287+
288+ let expected_prefix = format ! ( "{remote_name}/" ) ;
289+ if let Some ( branch_name) = name. strip_prefix ( & expected_prefix) {
290+ return Ok ( branch_name. to_owned ( ) ) ;
291+ } else {
292+ return Err ( anyhow:: anyhow!(
293+ "Remote branch name '{name}' does not start with expected prefix '{expected_prefix}'"
294+ ) ) ;
295+ }
294296
295- // Return the merge-base commit SHA as the base reference
296- let merge_base_sha = merge_base_oid. to_string ( ) ;
297- debug ! ( "Found merge-base commit as base reference: {merge_base_sha}" ) ;
298- Ok ( merge_base_sha)
299297}
300298
301299/// Like git_repo_base_repo_name but preserves the original case of the repository name.
@@ -569,13 +567,31 @@ pub fn find_head_sha() -> Result<String> {
569567 Ok ( head. id ( ) . to_string ( ) )
570568}
571569
572- pub fn find_base_sha ( ) -> Result < Option < String > > {
573- let github_event = std:: env:: var ( "GITHUB_EVENT_PATH" )
574- . map_err ( Error :: from)
575- . and_then ( |event_path| std:: fs:: read_to_string ( event_path) . map_err ( Error :: from) )
576- . context ( "Failed to read GitHub event path" ) ?;
570+ pub fn find_base_sha ( remote_name : & str ) -> Result < Option < String > > {
571+ if let Some ( pr_base_sha) = std:: env:: var ( "GITHUB_EVENT_PATH" )
572+ . ok ( )
573+ . and_then ( |event_path| std:: fs:: read_to_string ( event_path) . ok ( ) )
574+ . and_then ( |content| extract_pr_base_sha_from_event ( & content) )
575+ {
576+ debug ! ( "Using GitHub Actions PR base SHA from event payload: {pr_base_sha}" ) ;
577+ return Ok ( Some ( pr_base_sha) ) ;
578+ }
577579
578- extract_pr_base_sha_from_event ( & github_event)
580+ let repo = git2:: Repository :: open_from_env ( ) ?;
581+
582+ let head_commit = repo. head ( ) ?. peel_to_commit ( ) ?;
583+
584+ // Try to find the remote tracking branch
585+ let remote_branch_name = format ! ( "refs/remotes/{remote_name}/HEAD" ) ;
586+ let remote_ref = repo. find_reference ( & remote_branch_name) . map_err ( |e| {
587+ anyhow:: anyhow!( "Could not find remote tracking branch for {remote_name}: {e}" )
588+ } ) ?;
589+
590+ let remote_commit = remote_ref. peel_to_commit ( ) ?;
591+ let merge_base_oid = repo. merge_base ( head_commit. id ( ) , remote_commit. id ( ) ) ?;
592+ let merge_base_sha = merge_base_oid. to_string ( ) ;
593+ debug ! ( "Found merge-base commit as base reference: {merge_base_sha}" ) ;
594+ Ok ( Some ( merge_base_sha) )
579595}
580596
581597/// Extracts the PR head SHA from GitHub Actions event payload JSON.
@@ -595,15 +611,19 @@ fn extract_pr_head_sha_from_event(json_content: &str) -> Option<String> {
595611}
596612
597613/// Extracts the PR base SHA from GitHub Actions event payload JSON.
598- /// Returns Ok(None) if not a PR event or if SHA cannot be extracted.
599- /// Returns an error if we cannot parse the JSON.
600- fn extract_pr_base_sha_from_event ( json_content : & str ) -> Result < Option < String > > {
601- let v: Value = serde_json:: from_str ( json_content)
602- . context ( "Failed to parse GitHub event payload as JSON" ) ?;
614+ /// Returns None if not a PR event or if SHA cannot be extracted.
615+ fn extract_pr_base_sha_from_event ( json_content : & str ) -> Option < String > {
616+ let v: Value = match serde_json:: from_str ( json_content) {
617+ Ok ( v) => v,
618+ Err ( _) => {
619+ debug ! ( "Failed to parse GitHub event payload as JSON" ) ;
620+ return None ;
621+ }
622+ } ;
603623
604- Ok ( v. pointer ( "/pull_request/base/sha" )
624+ v. pointer ( "/pull_request/base/sha" )
605625 . and_then ( |s| s. as_str ( ) )
606- . map ( |s| s. to_owned ( ) ) )
626+ . map ( |s| s. to_owned ( ) )
607627}
608628
609629/// Given commit specs, repos and remote_name this returns a list of head
@@ -1705,7 +1725,7 @@ mod tests {
17051725 . to_string ( ) ;
17061726
17071727 assert_eq ! (
1708- extract_pr_base_sha_from_event( & pr_json) . unwrap ( ) ,
1728+ extract_pr_base_sha_from_event( & pr_json) ,
17091729 Some ( "55e6bc8c264ce95164314275d805f477650c440d" . to_owned( ) )
17101730 ) ;
17111731
@@ -1719,10 +1739,7 @@ mod tests {
17191739 } )
17201740 . to_string ( ) ;
17211741
1722- assert_eq ! ( extract_pr_base_sha_from_event( & push_json) . unwrap( ) , None ) ;
1723-
1724- // Test with malformed JSON
1725- assert ! ( extract_pr_base_sha_from_event( "invalid json {" ) . is_err( ) ) ;
1742+ assert_eq ! ( extract_pr_base_sha_from_event( & push_json) , None ) ;
17261743
17271744 // Test with missing base SHA
17281745 let incomplete_json = r#"{
@@ -1733,10 +1750,7 @@ mod tests {
17331750 }
17341751}"# ;
17351752
1736- assert_eq ! (
1737- extract_pr_base_sha_from_event( incomplete_json) . unwrap( ) ,
1738- None
1739- ) ;
1753+ assert_eq ! ( extract_pr_base_sha_from_event( incomplete_json) , None ) ;
17401754 }
17411755
17421756 #[ test]
@@ -1765,15 +1779,15 @@ mod tests {
17651779 fs:: write ( & event_file, pr_json) . expect ( "Failed to write event file" ) ;
17661780 std:: env:: set_var ( "GITHUB_EVENT_PATH" , event_file. to_str ( ) . unwrap ( ) ) ;
17671781
1768- let result = find_base_sha ( ) ;
1782+ let result = find_base_sha ( "origin/main" ) ;
17691783 assert_eq ! (
17701784 result. unwrap( ) . unwrap( ) ,
17711785 "55e6bc8c264ce95164314275d805f477650c440d"
17721786 ) ;
17731787
17741788 // Test without GITHUB_EVENT_PATH
17751789 std:: env:: remove_var ( "GITHUB_EVENT_PATH" ) ;
1776- let result = find_base_sha ( ) ;
1790+ let result = find_base_sha ( "origin/main" ) ;
17771791 assert ! ( result. is_err( ) ) ;
17781792 }
17791793}
0 commit comments