33use std:: fmt;
44use std:: path:: PathBuf ;
55
6- use anyhow:: { bail, format_err, Context as _ , Error , Result } ;
6+ use anyhow:: { bail, format_err, Error , Result } ;
77use chrono:: { DateTime , FixedOffset , TimeZone as _} ;
88use git2:: { Commit , Repository , Time } ;
99use if_chain:: if_chain;
@@ -272,30 +272,27 @@ pub fn git_repo_head_ref(repo: &git2::Repository) -> Result<String> {
272272}
273273
274274pub fn git_repo_base_ref ( repo : & git2:: Repository , remote_name : & str ) -> Result < String > {
275- // Get the current HEAD commit
276- let head_commit = repo. head ( ) ?. peel_to_commit ( ) ?;
277-
278- // Try to find the remote tracking branch
279275 let remote_branch_name = format ! ( "refs/remotes/{remote_name}/HEAD" ) ;
280276 let remote_ref = repo. find_reference ( & remote_branch_name) . map_err ( |e| {
281277 anyhow:: anyhow!( "Could not find remote tracking branch for {remote_name}: {e}" )
282278 } ) ?;
283279
284- find_merge_base_ref ( repo, & head_commit, & remote_ref)
285- }
286-
287- fn find_merge_base_ref (
288- repo : & git2:: Repository ,
289- head_commit : & git2:: Commit ,
290- remote_ref : & git2:: Reference ,
291- ) -> Result < String > {
292- let remote_commit = remote_ref. peel_to_commit ( ) ?;
293- let merge_base_oid = repo. merge_base ( head_commit. id ( ) , remote_commit. id ( ) ) ?;
294-
295- // Return the merge-base commit SHA as the base reference
296- let merge_base_sha = merge_base_oid. to_string ( ) ;
297- debug ! ( "Found merge-base commit as base reference: {merge_base_sha}" ) ;
298- Ok ( merge_base_sha)
280+ let name = remote_ref
281+ . resolve ( ) ?
282+ . shorthand ( )
283+ . ok_or ( anyhow:: anyhow!(
284+ "Could not find remote tracking branch for {remote_name}"
285+ ) ) ?
286+ . to_owned ( ) ;
287+
288+ let expected_prefix = format ! ( "{remote_name}/" ) ;
289+ if let Some ( branch_name) = name. strip_prefix ( & expected_prefix) {
290+ Ok ( branch_name. to_owned ( ) )
291+ } else {
292+ Err ( anyhow:: anyhow!(
293+ "Remote branch name '{name}' does not start with expected prefix '{expected_prefix}'"
294+ ) )
295+ }
299296}
300297
301298/// Like git_repo_base_repo_name but preserves the original case of the repository name.
@@ -569,13 +566,31 @@ pub fn find_head_sha() -> Result<String> {
569566 Ok ( head. id ( ) . to_string ( ) )
570567}
571568
572- pub fn find_base_sha ( ) -> Result < Option < String > > {
573- let github_event = std:: env:: var ( "GITHUB_EVENT_PATH" )
574- . map_err ( Error :: from)
575- . and_then ( |event_path| std:: fs:: read_to_string ( event_path) . map_err ( Error :: from) )
576- . context ( "Failed to read GitHub event path" ) ?;
569+ pub fn find_base_sha ( remote_name : & str ) -> Result < Option < String > > {
570+ if let Some ( pr_base_sha) = std:: env:: var ( "GITHUB_EVENT_PATH" )
571+ . ok ( )
572+ . and_then ( |event_path| std:: fs:: read_to_string ( event_path) . ok ( ) )
573+ . and_then ( |content| extract_pr_base_sha_from_event ( & content) )
574+ {
575+ debug ! ( "Using GitHub Actions PR base SHA from event payload: {pr_base_sha}" ) ;
576+ return Ok ( Some ( pr_base_sha) ) ;
577+ }
578+
579+ let repo = git2:: Repository :: open_from_env ( ) ?;
580+
581+ let head_commit = repo. head ( ) ?. peel_to_commit ( ) ?;
577582
578- extract_pr_base_sha_from_event ( & github_event)
583+ // Try to find the remote tracking branch
584+ let remote_branch_name = format ! ( "refs/remotes/{remote_name}/HEAD" ) ;
585+ let remote_ref = repo. find_reference ( & remote_branch_name) . map_err ( |e| {
586+ anyhow:: anyhow!( "Could not find remote tracking branch for {remote_name}: {e}" )
587+ } ) ?;
588+
589+ let remote_commit = remote_ref. peel_to_commit ( ) ?;
590+ let merge_base_oid = repo. merge_base ( head_commit. id ( ) , remote_commit. id ( ) ) ?;
591+ let merge_base_sha = merge_base_oid. to_string ( ) ;
592+ debug ! ( "Found merge-base commit as base reference: {merge_base_sha}" ) ;
593+ Ok ( Some ( merge_base_sha) )
579594}
580595
581596/// Extracts the PR head SHA from GitHub Actions event payload JSON.
@@ -595,15 +610,19 @@ fn extract_pr_head_sha_from_event(json_content: &str) -> Option<String> {
595610}
596611
597612/// Extracts the PR base SHA from GitHub Actions event payload JSON.
598- /// Returns Ok(None) if not a PR event or if SHA cannot be extracted.
599- /// Returns an error if we cannot parse the JSON.
600- fn extract_pr_base_sha_from_event ( json_content : & str ) -> Result < Option < String > > {
601- let v: Value = serde_json:: from_str ( json_content)
602- . context ( "Failed to parse GitHub event payload as JSON" ) ?;
613+ /// Returns None if not a PR event or if SHA cannot be extracted.
614+ fn extract_pr_base_sha_from_event ( json_content : & str ) -> Option < String > {
615+ let v: Value = match serde_json:: from_str ( json_content) {
616+ Ok ( v) => v,
617+ Err ( _) => {
618+ debug ! ( "Failed to parse GitHub event payload as JSON" ) ;
619+ return None ;
620+ }
621+ } ;
603622
604- Ok ( v. pointer ( "/pull_request/base/sha" )
623+ v. pointer ( "/pull_request/base/sha" )
605624 . and_then ( |s| s. as_str ( ) )
606- . map ( |s| s. to_owned ( ) ) )
625+ . map ( |s| s. to_owned ( ) )
607626}
608627
609628/// Given commit specs, repos and remote_name this returns a list of head
@@ -1705,7 +1724,7 @@ mod tests {
17051724 . to_string ( ) ;
17061725
17071726 assert_eq ! (
1708- extract_pr_base_sha_from_event( & pr_json) . unwrap ( ) ,
1727+ extract_pr_base_sha_from_event( & pr_json) ,
17091728 Some ( "55e6bc8c264ce95164314275d805f477650c440d" . to_owned( ) )
17101729 ) ;
17111730
@@ -1719,10 +1738,7 @@ mod tests {
17191738 } )
17201739 . to_string ( ) ;
17211740
1722- assert_eq ! ( extract_pr_base_sha_from_event( & push_json) . unwrap( ) , None ) ;
1723-
1724- // Test with malformed JSON
1725- assert ! ( extract_pr_base_sha_from_event( "invalid json {" ) . is_err( ) ) ;
1741+ assert_eq ! ( extract_pr_base_sha_from_event( & push_json) , None ) ;
17261742
17271743 // Test with missing base SHA
17281744 let incomplete_json = r#"{
@@ -1733,10 +1749,7 @@ mod tests {
17331749 }
17341750}"# ;
17351751
1736- assert_eq ! (
1737- extract_pr_base_sha_from_event( incomplete_json) . unwrap( ) ,
1738- None
1739- ) ;
1752+ assert_eq ! ( extract_pr_base_sha_from_event( incomplete_json) , None ) ;
17401753 }
17411754
17421755 #[ test]
@@ -1765,15 +1778,15 @@ mod tests {
17651778 fs:: write ( & event_file, pr_json) . expect ( "Failed to write event file" ) ;
17661779 std:: env:: set_var ( "GITHUB_EVENT_PATH" , event_file. to_str ( ) . unwrap ( ) ) ;
17671780
1768- let result = find_base_sha ( ) ;
1781+ let result = find_base_sha ( "origin" ) ;
17691782 assert_eq ! (
17701783 result. unwrap( ) . unwrap( ) ,
17711784 "55e6bc8c264ce95164314275d805f477650c440d"
17721785 ) ;
17731786
17741787 // Test without GITHUB_EVENT_PATH
17751788 std:: env:: remove_var ( "GITHUB_EVENT_PATH" ) ;
1776- let result = find_base_sha ( ) ;
1789+ let result = find_base_sha ( "origin" ) ;
17771790 assert ! ( result. is_err( ) ) ;
17781791 }
17791792}
0 commit comments