diff --git a/Cargo.toml b/Cargo.toml index 46100c2..ca9aad2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ base64 = "0.22" gix = { version = "0.84", features = ["max-performance", "parallel", "status", "blob-diff"] } gix-diff = "0.64" gix-status = "0.31" +gix-revwalk = "0.32" # Image processing (Kitty graphics protocol) image = { version = "0.25", default-features = false, features = ["png", "gif"] } diff --git a/src/main.rs b/src/main.rs index 69fbe30..bf4f65a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1066,65 +1066,74 @@ fn get_git_status(cwd: &str) -> Option { return None; } - // Get ahead/behind against upstream + // Get ahead/behind against upstream using git's configured upstream let mut ahead = 0usize; let mut behind = 0usize; if let Ok(head_id) = repo.head_id() { - if let Ok(head_ref) = repo.find_reference("HEAD") { - if let gix::refs::TargetRef::Symbolic(upstream_name) = head_ref.target() { - let upstream_full = format!("refs/remotes/origin/{}", upstream_name); - if let Ok(upstream_ref) = repo.find_reference(&upstream_full) { - let mut upstream_ref = upstream_ref; - if let Ok(upstream_id) = upstream_ref.peel_to_id() { - let head_id_detached = head_id.detach(); - let upstream_id_detached = upstream_id.detach(); - // Count ahead: commits reachable from head_id but not from upstream_id - let mut count = 0usize; - let mut seen = std::collections::HashSet::new(); - let mut queue = vec![head_id_detached]; - while let Some(current) = queue.pop() { - if seen.contains(¤t) { - continue; - } - seen.insert(current); - if let Ok(commit) = repo.find_commit(current.clone()) { - for parent_oid in commit.parent_ids() { - let parent_oid_detached = parent_oid.detach(); - if parent_oid_detached == upstream_id_detached { + if let Ok(Some(full_name)) = repo.head_name() { + // Use gix's branch_remote_tracking_ref_name to get the remote tracking branch + if let Some(upstream_result) = repo.branch_remote_tracking_ref_name(full_name.as_ref(), gix::remote::Direction::Fetch) { + if let Ok(upstream_ref_name) = upstream_result { + if let Ok(upstream_ref) = repo.find_reference(upstream_ref_name.as_ref()) { + let mut upstream_ref = upstream_ref; + if let Ok(upstream_id) = upstream_ref.peel_to_id() { + let head_id_detached = head_id.detach(); + let upstream_id_detached = upstream_id.detach(); + + // Find merge base between HEAD and upstream + let mut graph = gix_revwalk::Graph::new(&repo, None); + if let Ok(merge_base_id) = repo.merge_base_with_graph(head_id_detached, upstream_id_detached, &mut graph) { + let merge_base_detached = merge_base_id.detach(); + + // Count ahead: commits from merge_base to HEAD (exclusive of merge_base) + let mut count = 0usize; + let mut seen = std::collections::HashSet::new(); + let mut queue = vec![head_id_detached]; + while let Some(current) = queue.pop() { + if seen.contains(¤t) { + continue; + } + seen.insert(current); + if current == merge_base_detached { break; } - if !seen.contains(&parent_oid_detached) { - queue.push(parent_oid_detached); + if let Ok(commit) = repo.find_commit(current.clone()) { + for parent_oid in commit.parent_ids() { + let parent_oid_detached = parent_oid.detach(); + if !seen.contains(&parent_oid_detached) { + queue.push(parent_oid_detached); + } + } } + count += 1; } - } - count += 1; - } - ahead = count; + ahead = count; - // Count behind: commits reachable from upstream_id but not from head_id - let mut count = 0usize; - let mut seen = std::collections::HashSet::new(); - let mut queue = vec![upstream_id_detached]; - while let Some(current) = queue.pop() { - if seen.contains(¤t) { - continue; - } - seen.insert(current); - if let Ok(commit) = repo.find_commit(current.clone()) { - for parent_oid in commit.parent_ids() { - let parent_oid_detached = parent_oid.detach(); - if parent_oid_detached == head_id_detached { + // Count behind: commits from merge_base to upstream (exclusive of merge_base) + let mut count = 0usize; + let mut seen = std::collections::HashSet::new(); + let mut queue = vec![upstream_id_detached]; + while let Some(current) = queue.pop() { + if seen.contains(¤t) { + continue; + } + seen.insert(current); + if current == merge_base_detached { break; } - if !seen.contains(&parent_oid_detached) { - queue.push(parent_oid_detached); + if let Ok(commit) = repo.find_commit(current.clone()) { + for parent_oid in commit.parent_ids() { + let parent_oid_detached = parent_oid.detach(); + if !seen.contains(&parent_oid_detached) { + queue.push(parent_oid_detached); + } + } } + count += 1; } + behind = count; } - count += 1; } - behind = count; } } }