Skip to content

Commit

Permalink
chore(turbo-tasks): Move Invalidator struct from manager.rs to invali…
Browse files Browse the repository at this point in the history
…dation.rs (vercel#69073)

## Why?

`manager.rs` is getting *very* long (>2000 LOC) and hard to read
through. I've been only been making this worse in my stack. This PR is a
contribution to help offset this.

`Invalidator` seems like something that's very weakly coupled with the
rest of the manager, and there's a logical place for it in the
`invalidation` module.

## Testing

```
cargo nextest r -p turbo-tasks -p turbo-tasks-memory
```
  • Loading branch information
bgw authored Sep 4, 2024
1 parent 12cbab2 commit 4b4fd80
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 130 deletions.
129 changes: 128 additions & 1 deletion turbopack/crates/turbo-tasks/src/invalidation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,138 @@ use std::{
fmt::Display,
hash::{Hash, Hasher},
mem::replace,
sync::{Arc, Weak},
};

use anyhow::Result;
use indexmap::{map::Entry, IndexMap, IndexSet};
use serde::{de::Visitor, Deserialize, Serialize};
use tokio::runtime::Handle;

use crate::{magic_any::HasherMut, util::StaticOrArc};
use crate::{
magic_any::HasherMut,
manager::{current_task, with_turbo_tasks},
trace::TraceRawVcs,
util::StaticOrArc,
TaskId, TurboTasksApi,
};

/// Get an [`Invalidator`] that can be used to invalidate the current task
/// based on external events.
pub fn get_invalidator() -> Invalidator {
let handle = Handle::current();
Invalidator {
task: current_task("turbo_tasks::get_invalidator()"),
turbo_tasks: with_turbo_tasks(Arc::downgrade),
handle,
}
}

pub struct Invalidator {
task: TaskId,
turbo_tasks: Weak<dyn TurboTasksApi>,
handle: Handle,
}

impl Invalidator {
pub fn invalidate(self) {
let Invalidator {
task,
turbo_tasks,
handle,
} = self;
let _ = handle.enter();
if let Some(turbo_tasks) = turbo_tasks.upgrade() {
turbo_tasks.invalidate(task);
}
}

pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
let Invalidator {
task,
turbo_tasks,
handle,
} = self;
let _ = handle.enter();
if let Some(turbo_tasks) = turbo_tasks.upgrade() {
turbo_tasks.invalidate_with_reason(
task,
(Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
);
}
}

pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
let Invalidator {
task,
turbo_tasks,
handle,
} = self;
let _ = handle.enter();
if let Some(turbo_tasks) = turbo_tasks.upgrade() {
turbo_tasks
.invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
}
}
}

impl Hash for Invalidator {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.task.hash(state);
}
}

impl PartialEq for Invalidator {
fn eq(&self, other: &Self) -> bool {
self.task == other.task
}
}

impl Eq for Invalidator {}

impl TraceRawVcs for Invalidator {
fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
// nothing here
}
}

impl Serialize for Invalidator {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_newtype_struct("Invalidator", &self.task)
}
}

impl<'de> Deserialize<'de> for Invalidator {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct V;

impl<'de> Visitor<'de> for V {
type Value = Invalidator;

fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "an Invalidator")
}

fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Invalidator {
task: TaskId::deserialize(deserializer)?,
turbo_tasks: with_turbo_tasks(Arc::downgrade),
handle: tokio::runtime::Handle::current(),
})
}
}
deserializer.deserialize_newtype_struct("Invalidator", V)
}
}

pub trait DynamicEqHash {
fn as_any(&self) -> &dyn Any;
Expand Down
12 changes: 6 additions & 6 deletions turbopack/crates/turbo-tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,16 @@ pub use completion::{Completion, Completions};
pub use display::ValueToString;
pub use id::{ExecutionId, FunctionId, TaskId, TraitTypeId, ValueTypeId, TRANSIENT_TASK_BIT};
pub use invalidation::{
DynamicEqHash, InvalidationReason, InvalidationReasonKind, InvalidationReasonSet,
get_invalidator, DynamicEqHash, InvalidationReason, InvalidationReasonKind,
InvalidationReasonSet, Invalidator,
};
pub use join_iter_ext::{JoinIterExt, TryFlatJoinIterExt, TryJoinIterExt};
pub use magic_any::MagicAny;
pub use manager::{
dynamic_call, dynamic_this_call, emit, get_invalidator, mark_finished, mark_stateful,
prevent_gc, run_once, run_once_with_reason, spawn_blocking, spawn_thread, trait_call,
turbo_tasks, CurrentCellRef, Invalidator, ReadConsistency, TaskPersistence, TurboTasks,
TurboTasksApi, TurboTasksBackendApi, TurboTasksBackendApiExt, TurboTasksCallApi, Unused,
UpdateInfo,
dynamic_call, dynamic_this_call, emit, mark_finished, mark_stateful, prevent_gc, run_once,
run_once_with_reason, spawn_blocking, spawn_thread, trait_call, turbo_tasks, CurrentCellRef,
ReadConsistency, TaskPersistence, TurboTasks, TurboTasksApi, TurboTasksBackendApi,
TurboTasksBackendApiExt, TurboTasksCallApi, Unused, UpdateInfo,
};
pub use native_function::{FunctionMeta, NativeFunction};
pub use raw_vc::{CellId, RawVc, ReadRawVcFuture, ResolveTypeError};
Expand Down
125 changes: 2 additions & 123 deletions turbopack/crates/turbo-tasks/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
any::Any,
borrow::Cow,
future::Future,
hash::{BuildHasherDefault, Hash},
hash::BuildHasherDefault,
mem::take,
panic::AssertUnwindSafe,
pin::Pin,
Expand All @@ -18,7 +18,7 @@ use anyhow::{anyhow, Result};
use auto_hash_map::AutoMap;
use futures::FutureExt;
use rustc_hash::FxHasher;
use serde::{de::Visitor, Deserialize, Serialize};
use serde::{Deserialize, Serialize};
use tokio::{runtime::Handle, select, task_local};
use tokio_util::task::TaskTracker;
use tracing::{info_span, instrument, trace_span, Instrument, Level};
Expand Down Expand Up @@ -1513,112 +1513,6 @@ pub(crate) fn current_task(from: &str) -> TaskId {
}
}

pub struct Invalidator {
task: TaskId,
turbo_tasks: Weak<dyn TurboTasksApi>,
handle: Handle,
}

impl Hash for Invalidator {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.task.hash(state);
}
}

impl PartialEq for Invalidator {
fn eq(&self, other: &Self) -> bool {
self.task == other.task
}
}

impl Eq for Invalidator {}

impl Invalidator {
pub fn invalidate(self) {
let Invalidator {
task,
turbo_tasks,
handle,
} = self;
let _ = handle.enter();
if let Some(turbo_tasks) = turbo_tasks.upgrade() {
turbo_tasks.invalidate(task);
}
}

pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
let Invalidator {
task,
turbo_tasks,
handle,
} = self;
let _ = handle.enter();
if let Some(turbo_tasks) = turbo_tasks.upgrade() {
turbo_tasks.invalidate_with_reason(
task,
(Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
);
}
}

pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
let Invalidator {
task,
turbo_tasks,
handle,
} = self;
let _ = handle.enter();
if let Some(turbo_tasks) = turbo_tasks.upgrade() {
turbo_tasks
.invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
}
}
}

impl TraceRawVcs for Invalidator {
fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
// nothing here
}
}

impl Serialize for Invalidator {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_newtype_struct("Invalidator", &self.task)
}
}

impl<'de> Deserialize<'de> for Invalidator {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct V;

impl<'de> Visitor<'de> for V {
type Value = Invalidator;

fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "an Invalidator")
}

fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Invalidator {
task: TaskId::deserialize(deserializer)?,
turbo_tasks: weak_turbo_tasks(),
handle: tokio::runtime::Handle::current(),
})
}
}
deserializer.deserialize_newtype_struct("Invalidator", V)
}
}

pub async fn run_once<T: Send + 'static>(
tt: Arc<dyn TurboTasksApi>,
future: impl Future<Output = Result<T>> + Send + 'static,
Expand Down Expand Up @@ -1704,10 +1598,6 @@ pub fn with_turbo_tasks<T>(func: impl FnOnce(&Arc<dyn TurboTasksApi>) -> T) -> T
TURBO_TASKS.with(|arc| func(arc))
}

pub fn weak_turbo_tasks() -> Weak<dyn TurboTasksApi> {
TURBO_TASKS.with(Arc::downgrade)
}

pub fn with_turbo_tasks_for_testing<T>(
tt: Arc<dyn TurboTasksApi>,
current_task: TaskId,
Expand Down Expand Up @@ -1738,17 +1628,6 @@ pub fn current_task_for_testing() -> TaskId {
CURRENT_GLOBAL_TASK_STATE.with(|ts| ts.read().unwrap().task_id)
}

/// Get an [`Invalidator`] that can be used to invalidate the current task
/// based on external events.
pub fn get_invalidator() -> Invalidator {
let handle = Handle::current();
Invalidator {
task: current_task("turbo_tasks::get_invalidator()"),
turbo_tasks: weak_turbo_tasks(),
handle,
}
}

/// Marks the current task as finished. This excludes it from waiting for
/// strongly consistency.
pub fn mark_finished() {
Expand Down

0 comments on commit 4b4fd80

Please sign in to comment.