Skip to content

Commit

Permalink
feat: add channel priority to solve task and expose to python solve (c…
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminLowry authored Apr 25, 2024
1 parent f3a0ce3 commit 926a876
Show file tree
Hide file tree
Showing 17 changed files with 333 additions and 41 deletions.
3 changes: 2 additions & 1 deletion crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use rattler_repodata_gateway::fetch::{
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{
libsolv_c::{self},
resolvo, SolverImpl, SolverTask,
resolvo, ChannelPriority, SolverImpl, SolverTask,
};
use reqwest::Client;
use std::sync::Arc;
Expand Down Expand Up @@ -235,6 +235,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
specs,
pinned_packages: Vec::new(),
timeout: opt.timeout.map(Duration::from_millis),
channel_priority: ChannelPriority::Strict,
};

// Next, use a solver to solve this specific problem. This provides us with all the operations
Expand Down
4 changes: 3 additions & 1 deletion crates/rattler_solve/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingM
use rattler_conda_types::ParseStrictness::Strict;
use rattler_conda_types::{Channel, ChannelConfig, MatchSpec};
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{SolverImpl, SolverTask};
use rattler_solve::{ChannelPriority, SolverImpl, SolverTask};

fn conda_json_path() -> String {
format!(
Expand Down Expand Up @@ -69,6 +69,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {
virtual_packages: vec![],
specs: specs.clone(),
timeout: None,
channel_priority: ChannelPriority::Strict,
}))
.unwrap()
});
Expand All @@ -85,6 +86,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {
virtual_packages: vec![],
specs: specs.clone(),
timeout: None,
channel_priority: ChannelPriority::Strict,
}))
.unwrap()
});
Expand Down
19 changes: 19 additions & 0 deletions crates/rattler_solve/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ impl fmt::Display for SolveError {
}
}

/// Represents the channel priority option to use during solves.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum ChannelPriority {
/// The channel that the package is first found in will be used as the only channel
/// for that package.
Strict,

// Conda also has "Flexible" as an option, where packages present in multiple channels
// are only taken from lower-priority channels when this prevents unsatisfiable environment
// errors, but this would need implementation in the solvers.
// Flexible,
/// Packages can be retrieved from any channel as package version takes precedence.
Disabled,
}

/// Represents a dependency resolution task, to be solved by one of the backends (currently only
/// libsolv is supported)
pub struct SolverTask<TAvailablePackagesIterator> {
Expand Down Expand Up @@ -102,6 +117,10 @@ pub struct SolverTask<TAvailablePackagesIterator> {

/// The timeout after which the solver should stop
pub timeout: Option<std::time::Duration>,

/// The channel priority to solve with, either [`ChannelPriority::Strict`] or
/// [`ChannelPriority::Disabled`]
pub channel_priority: ChannelPriority,
}

/// A representation of a collection of [`RepoDataRecord`] usable by a [`SolverImpl`]
Expand Down
8 changes: 6 additions & 2 deletions crates/rattler_solve/src/libsolv_c/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,14 @@ pub fn cache_repodata(_url: String, _data: &[RepoDataRecord]) -> LibcByteSlice {
/// Note: this function relies on primitives that are only available on unix-like operating systems,
/// and will panic if called from another platform (e.g. Windows)
#[cfg(target_family = "unix")]
pub fn cache_repodata(url: String, data: &[RepoDataRecord]) -> LibcByteSlice {
pub fn cache_repodata(
url: String,
data: &[RepoDataRecord],
channel_priority: Option<i32>,
) -> LibcByteSlice {
// Add repodata to a new pool + repo
let pool = Pool::default();
let repo = Repo::new(&pool, url);
let repo = Repo::new(&pool, url, channel_priority.unwrap_or(0));
add_repodata_records(&pool, &repo, data);

// Export repo to .solv in memory
Expand Down
65 changes: 56 additions & 9 deletions crates/rattler_solve/src/libsolv_c/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! Provides an solver implementation based on the [`rattler_libsolv_c`] crate.
use crate::{IntoRepoData, SolverRepoData};
use crate::{ChannelPriority, IntoRepoData, SolverRepoData};
use crate::{SolveError, SolverTask};
pub use input::cache_repodata;
use input::{add_repodata_records, add_solv_file, add_virtual_packages};
pub use libc_byte_slice::LibcByteSlice;
use output::get_required_packages;
use rattler_conda_types::RepoDataRecord;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::ffi::CString;
use std::mem::ManuallyDrop;
use wrapper::{
Expand Down Expand Up @@ -106,8 +106,47 @@ impl super::SolverImpl for Solver {
});
pool.set_debug_level(Verbosity::Low);

let repodatas: Vec<Self::RepoData<'_>> = task
.available_packages
.into_iter()
.map(IntoRepoData::into)
.collect();

// Determine the channel priority for each channel in the repodata in the order in which
// the repodatas are passed, where the first channel will have the highest priority value
// and each successive channel will descend in priority value. If not strict, the highest
// priority value will be 0 and the channel priority map will not be populated as it will
// not be used.
let mut highest_priority: i32 = 0;
let channel_priority: HashMap<String, i32> =
if task.channel_priority == ChannelPriority::Strict {
let mut seen_channels = HashSet::new();
let mut channel_order: Vec<String> = Vec::new();
for channel in repodatas
.iter()
.filter(|&r| !r.records.is_empty())
.map(|r| r.records[0].channel.clone())
{
if !seen_channels.contains(&channel) {
channel_order.push(channel.clone());
seen_channels.insert(channel);
}
}
let mut channel_priority = HashMap::new();
for (index, channel) in channel_order.iter().enumerate() {
let reverse_index = channel_order.len() - index;
if index == 0 {
highest_priority = reverse_index as i32;
}
channel_priority.insert(channel.clone(), reverse_index as i32);
}
channel_priority
} else {
HashMap::new()
};

// Add virtual packages
let repo = Repo::new(&pool, "virtual_packages");
let repo = Repo::new(&pool, "virtual_packages", highest_priority);
add_virtual_packages(&pool, &repo, &task.virtual_packages);

// Mark the virtual packages as installed.
Expand All @@ -116,15 +155,19 @@ impl super::SolverImpl for Solver {
// Create repos for all channel + platform combinations
let mut repo_mapping = HashMap::new();
let mut all_repodata_records = Vec::new();
for repodata in task.available_packages.into_iter().map(IntoRepoData::into) {
for repodata in repodatas.iter() {
if repodata.records.is_empty() {
continue;
}

let channel_name = &repodata.records[0].channel;

// We dont want to drop the Repo, its stored in the pool anyway.
let repo = ManuallyDrop::new(Repo::new(&pool, channel_name));
let priority: i32 = if task.channel_priority == ChannelPriority::Strict {
*channel_priority.get(channel_name).unwrap()
} else {
0
};
let repo = ManuallyDrop::new(Repo::new(&pool, channel_name, priority));

if let Some(solv_file) = repodata.solv_file {
add_solv_file(&pool, &repo, solv_file);
Expand All @@ -134,19 +177,19 @@ impl super::SolverImpl for Solver {

// Keep our own info about repodata_records
repo_mapping.insert(repo.id(), repo_mapping.len());
all_repodata_records.push(repodata.records);
all_repodata_records.push(repodata.records.clone());
}

// Create a special pool for records that are already installed or locked.
let repo = Repo::new(&pool, "locked");
let repo = Repo::new(&pool, "locked", highest_priority);
let installed_solvables = add_repodata_records(&pool, &repo, &task.locked_packages);

// Also add the installed records to the repodata
repo_mapping.insert(repo.id(), repo_mapping.len());
all_repodata_records.push(task.locked_packages.iter().collect());

// Create a special pool for records that are pinned and cannot be changed.
let repo = Repo::new(&pool, "pinned");
let repo = Repo::new(&pool, "pinned", highest_priority);
let pinned_solvables = add_repodata_records(&pool, &repo, &task.pinned_packages);

// Also add the installed records to the repodata
Expand Down Expand Up @@ -179,6 +222,10 @@ impl super::SolverImpl for Solver {
let mut solver = pool.create_solver();
solver.set_flag(SolverFlag::allow_uninstall(), true);
solver.set_flag(SolverFlag::allow_downgrade(), true);
solver.set_flag(
SolverFlag::strict_channel_priority(),
task.channel_priority == ChannelPriority::Strict,
);

let transaction = solver.solve(&mut goal).map_err(SolveError::Unsolvable)?;

Expand Down
8 changes: 7 additions & 1 deletion crates/rattler_solve/src/libsolv_c/wrapper/flags.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::ffi::{SOLVER_FLAG_ALLOW_DOWNGRADE, SOLVER_FLAG_ALLOW_UNINSTALL};
use super::ffi::{
SOLVER_FLAG_ALLOW_DOWNGRADE, SOLVER_FLAG_ALLOW_UNINSTALL, SOLVER_FLAG_STRICT_REPO_PRIORITY,
};

#[repr(transparent)]
pub struct SolverFlag(u32);
Expand All @@ -12,6 +14,10 @@ impl SolverFlag {
SolverFlag(SOLVER_FLAG_ALLOW_DOWNGRADE)
}

pub fn strict_channel_priority() -> SolverFlag {
SolverFlag(SOLVER_FLAG_STRICT_REPO_PRIORITY)
}

pub fn inner(self) -> i32 {
self.0 as i32
}
Expand Down
11 changes: 6 additions & 5 deletions crates/rattler_solve/src/libsolv_c/wrapper/repo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ impl<'pool> Drop for Repo<'pool> {

impl<'pool> Repo<'pool> {
/// Constructs a repo in the provided pool, associated to the given url
pub fn new(pool: &Pool, url: impl AsRef<str>) -> Repo<'_> {
pub fn new(pool: &Pool, url: impl AsRef<str>, priority: i32) -> Repo<'_> {
let c_url = c_string(url);

unsafe {
let repo_ptr = ffi::repo_create(pool.raw_ptr(), c_url.as_ptr());
let non_null_ptr = NonNull::new(repo_ptr).expect("repo ptr was null");
let mut non_null_ptr = NonNull::new(repo_ptr).expect("repo ptr was null");
non_null_ptr.as_mut().priority = priority;
Repo(non_null_ptr, PhantomData)
}
}
Expand Down Expand Up @@ -120,7 +121,7 @@ mod tests {
#[test]
fn test_repo_creation() {
let pool = Pool::default();
let mut _repo = Repo::new(&pool, "conda-forge");
let mut _repo = Repo::new(&pool, "conda-forge", 0);
}

#[test]
Expand All @@ -132,7 +133,7 @@ mod tests {
{
// Create a pool and a repo
let pool = Pool::default();
let repo = Repo::new(&pool, "conda-forge");
let repo = Repo::new(&pool, "conda-forge", 0);

// Add a solvable with a particular name
let solvable_id = repo.add_solvable();
Expand All @@ -148,7 +149,7 @@ mod tests {

// Create a clean pool and repo
let pool = Pool::default();
let repo = Repo::new(&pool, "conda-forge");
let repo = Repo::new(&pool, "conda-forge", 0);

// Open and read the .solv file
let mode = c_string("rb");
Expand Down
14 changes: 9 additions & 5 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Provides an solver implementation based on the [`resolvo`] crate.
use crate::{IntoRepoData, SolveError, SolverRepoData, SolverTask};
use crate::{ChannelPriority, IntoRepoData, SolveError, SolverRepoData, SolverTask};
use rattler_conda_types::package::ArchiveType;
use rattler_conda_types::{
GenericVirtualPackage, MatchSpec, NamelessMatchSpec, PackageRecord, ParseMatchSpecError,
Expand Down Expand Up @@ -177,6 +177,7 @@ impl<'a> CondaDependencyProvider<'a> {
virtual_packages: &'a [GenericVirtualPackage],
match_specs: &[MatchSpec],
stop_time: Option<std::time::SystemTime>,
channel_priority: ChannelPriority,
) -> Self {
let pool = Rc::new(Pool::default());
let mut records: HashMap<NameId, Candidates> = HashMap::default();
Expand Down Expand Up @@ -283,10 +284,12 @@ impl<'a> CondaDependencyProvider<'a> {
}

// Enforce channel priority
// This functions makes the assumtion that the records are given in order of the channels.
if let Some(first_channel) = package_name_found_in_channel
.get(&record.package_record.name.as_normalized().to_string())
{
// This function makes the assumption that the records are given in order of the channels.
if let (Some(first_channel), ChannelPriority::Strict) = (
package_name_found_in_channel
.get(&record.package_record.name.as_normalized().to_string()),
channel_priority,
) {
// Add the record to the excluded list when it is from a different channel.
if first_channel != &&record.channel {
tracing::debug!(
Expand Down Expand Up @@ -445,6 +448,7 @@ impl super::SolverImpl for Solver {
&task.virtual_packages,
task.specs.clone().as_ref(),
stop_time,
task.channel_priority,
);
let pool = provider.pool.clone();

Expand Down
Loading

0 comments on commit 926a876

Please sign in to comment.