From 3f1d3f7b5dd53a981bc2a788113140f0756f31f4 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 13 Nov 2024 23:17:42 -0500 Subject: [PATCH] Include extras and dependency groups in derivation chains --- .../uv-distribution-types/src/derivation.rs | 24 ++++++++++-- crates/uv-resolver/src/resolver/derivation.rs | 38 ++++++++++++------- crates/uv/tests/it/lock.rs | 4 +- crates/uv/tests/it/sync.rs | 6 +-- 4 files changed, 51 insertions(+), 21 deletions(-) diff --git a/crates/uv-distribution-types/src/derivation.rs b/crates/uv-distribution-types/src/derivation.rs index 6084a9f0b056..5e974b52d7fe 100644 --- a/crates/uv-distribution-types/src/derivation.rs +++ b/crates/uv-distribution-types/src/derivation.rs @@ -1,4 +1,4 @@ -use uv_normalize::PackageName; +use uv_normalize::{ExtraName, GroupName, PackageName}; use uv_pep440::Version; use version_ranges::Ranges; @@ -65,6 +65,10 @@ impl IntoIterator for DerivationChain { pub struct DerivationStep { /// The name of the package. pub name: PackageName, + /// The enabled extra of the package, if any. + pub extra: Option, + /// The enabled dependency group of the package, if any. + pub group: Option, /// The version of the package. pub version: Version, /// The constraints applied to the subsequent package in the chain. @@ -73,9 +77,17 @@ pub struct DerivationStep { impl DerivationStep { /// Create a [`DerivationStep`] from a package name and version. - pub fn new(name: PackageName, version: Version, range: Ranges) -> Self { + pub fn new( + name: PackageName, + extra: Option, + group: Option, + version: Version, + range: Ranges, + ) -> Self { Self { name, + extra, + group, version, range, } @@ -84,6 +96,12 @@ impl DerivationStep { impl std::fmt::Display for DerivationStep { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}=={}", self.name, self.version) + if let Some(extra) = &self.extra { + write!(f, "{}[{}]=={}", self.name, extra, self.version) + } else if let Some(group) = &self.group { + write!(f, "{}:{}=={}", self.name, group, self.version) + } else { + write!(f, "{}=={}", self.name, self.version) + } } } diff --git a/crates/uv-resolver/src/resolver/derivation.rs b/crates/uv-resolver/src/resolver/derivation.rs index 1b4fe3dbd8ff..c0218d06dafc 100644 --- a/crates/uv-resolver/src/resolver/derivation.rs +++ b/crates/uv-resolver/src/resolver/derivation.rs @@ -1,11 +1,12 @@ use std::collections::VecDeque; +use petgraph::visit::EdgeRef; use petgraph::Direction; use pubgrub::{Kind, Range, SelectedDependencies, State}; use rustc_hash::FxHashSet; use uv_distribution_types::{ - DerivationChain, DerivationStep, DistRef, Name, Node, Resolution, ResolvedDist, + DerivationChain, DerivationStep, DistRef, Edge, Name, Node, Resolution, ResolvedDist, }; use uv_pep440::Version; @@ -40,11 +41,11 @@ impl DerivationChainBuilder { // Perform a BFS to find the shortest path to the root. let mut queue = VecDeque::new(); - queue.push_back((target, Vec::new())); + queue.push_back((target, None, None, Vec::new())); // TODO(charlie): Consider respecting markers here. let mut seen = FxHashSet::default(); - while let Some((node, mut path)) = queue.pop_front() { + while let Some((node, extra, group, mut path)) = queue.pop_front() { if !seen.insert(node) { continue; } @@ -55,16 +56,25 @@ impl DerivationChainBuilder { return Some(DerivationChain::from_iter(path)); } Node::Dist { dist, .. } => { - path.push(DerivationStep::new( - dist.name().clone(), - dist.version().clone(), - Range::empty(), - )); - for neighbor in resolution - .graph() - .neighbors_directed(node, Direction::Incoming) - { - queue.push_back((neighbor, path.clone())); + for edge in resolution.graph().edges_directed(node, Direction::Incoming) { + let mut path = path.clone(); + path.push(DerivationStep::new( + dist.name().clone(), + extra.clone(), + group.clone(), + dist.version().clone(), + Range::empty(), + )); + let target = edge.source(); + let extra = match edge.weight() { + Edge::Optional(extra, ..) => Some(extra.clone()), + _ => None, + }; + let group = match edge.weight() { + Edge::Dev(group, ..) => Some(group.clone()), + _ => None, + }; + queue.push_back((target, extra, group, path)); } } } @@ -109,6 +119,8 @@ impl DerivationChainBuilder { // Add to the current path. path.push(DerivationStep::new( name.clone(), + p1.extra().cloned(), + p1.dev().cloned(), version.clone(), v2.clone(), )); diff --git a/crates/uv/tests/it/lock.rs b/crates/uv/tests/it/lock.rs index 95f49cf190d6..2385fc44e784 100644 --- a/crates/uv/tests/it/lock.rs +++ b/crates/uv/tests/it/lock.rs @@ -19940,7 +19940,7 @@ fn lock_derivation_chain_extra() -> Result<()> { ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SyntaxError: Missing parentheses in call to 'print'. Did you mean print(...)? - help: `wsgiref` was included because `project==0.1.0` depends on `wsgiref (>=0.1)` + help: `wsgiref` was included because `project[wsgi]==0.1.0` depends on `wsgiref (>=0.1)` "###); Ok(()) @@ -20000,7 +20000,7 @@ fn lock_derivation_chain_group() -> Result<()> { ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SyntaxError: Missing parentheses in call to 'print'. Did you mean print(...)? - help: `wsgiref` was included because `project==0.1.0` depends on `wsgiref (*)` + help: `wsgiref` was included because `project:wsgi==0.1.0` depends on `wsgiref (*)` "###); Ok(()) diff --git a/crates/uv/tests/it/sync.rs b/crates/uv/tests/it/sync.rs index 11e593f6192e..1ff86a627fac 100644 --- a/crates/uv/tests/it/sync.rs +++ b/crates/uv/tests/it/sync.rs @@ -792,7 +792,7 @@ fn sync_build_isolation_extra() -> Result<()> { File "", line 8, in ModuleNotFoundError: No module named 'hatchling' - help: `source-distribution` was included because `project==0.1.0` depends on `source-distribution` + help: `source-distribution` was included because `project[compile]==0.1.0` depends on `source-distribution` "###); // Running `uv sync` with `--all-extras` should also fail. @@ -4398,7 +4398,7 @@ fn sync_derivation_chain_extra() -> Result<()> { ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SyntaxError: Missing parentheses in call to 'print'. Did you mean print(...)? - help: `wsgiref` was included because `project==0.1.0` depends on `wsgiref` + help: `wsgiref` was included because `project[wsgi]==0.1.0` depends on `wsgiref` "###); Ok(()) @@ -4464,7 +4464,7 @@ fn sync_derivation_chain_group() -> Result<()> { ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SyntaxError: Missing parentheses in call to 'print'. Did you mean print(...)? - help: `wsgiref` was included because `project==0.1.0` depends on `wsgiref` + help: `wsgiref` was included because `project:wsgi==0.1.0` depends on `wsgiref` "###); Ok(())