Commit debac12: misc changes (#9)
* misc changes

* format
codekansas authored Mar 5, 2024
1 parent 892d4dd commit debac12
Showing 8 changed files with 87 additions and 20 deletions.
Makefile (1 addition & 1 deletion)

```diff
@@ -7,7 +7,7 @@ all: format static-checks test

 format:
 	@black $(py-files)
-	@ruff --fix $(py-files)
+	@ruff format $(py-files)
 .PHONY: format

 static-checks:
```
examples/mnist.py (2 additions & 4 deletions)

```diff
@@ -70,7 +70,7 @@ class Config(xax.Config):

 class MnistClassification(xax.Task[Config]):
     def get_model(self) -> Model:
-        return Model(self.prng_key)
+        return Model(self.prng_key())

     def get_optimizer(self) -> optax.GradientTransformation:
         return optax.adam(1e-3)
@@ -100,6 +100,4 @@ def get_dataset(self, phase: xax.Phase) -> MNIST:

 if __name__ == "__main__":
     # python -m examples.mnist
-    config = Config(batch_size=16)
-    config.train_dl.num_workers = 1
-    MnistClassification.launch(config)
+    MnistClassification.launch(Config(batch_size=16))
```
xax/core/conf.py (1 addition & 0 deletions)

```diff
@@ -98,6 +98,7 @@ class Triton:
 @dataclass
 class Experiment:
     default_random_seed: int = field(1337, help="The default random seed to use")
+    max_workers: int = field(32, help="Maximum number of workers to use")


 @dataclass
```
xax/task/mixins/data_loader.py (5 additions & 8 deletions)

```diff
@@ -38,7 +38,6 @@ class DataloaderErrorConfig:

 @dataclass
 class DataloaderConfig:
-    batch_size: int = field(MISSING, help="Size of each batch")
     num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
     prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
     error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
@@ -49,11 +48,11 @@ class DataloadersConfig(ProcessConfig, BaseConfig):
     batch_size: int = field(MISSING, help="Size of each batch")
     raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
     train_dl: DataloaderConfig = field(
-        DataloaderConfig(batch_size=II("batch_size")),
+        DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
         help="Train dataloader config",
     )
     valid_dl: DataloaderConfig = field(
-        DataloaderConfig(batch_size=II("batch_size"), num_workers=1),
+        DataloaderConfig(num_workers=1),
         help="Valid dataloader config",
     )
     debug_dataloader: bool = field(False, help="Debug dataloaders")
@@ -64,9 +63,7 @@ class DataloadersConfig(ProcessConfig, BaseConfig):

 class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
     def __init__(self, config: Config) -> None:
-        if is_missing(config, "batch_size") and (
-            is_missing(config.train_dl, "batch_size") or is_missing(config.valid_dl, "batch_size")
-        ):
+        if is_missing(config, "batch_size"):
             config.batch_size = self.get_batch_size()

         super().__init__(config)
@@ -120,10 +117,10 @@ def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader

         return Dataloader(
             dataset=dataset,
-            batch_size=cfg.batch_size,
+            batch_size=self.config.batch_size,
             num_workers=0 if debugging else cfg.num_workers,
             prefetch_factor=cfg.prefetch_factor,
-            ctx=self.multiprocessing_context,
+            mp_manager=self.multiprocessing_manager,
             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
             collate_worker_init_fn=self.collate_worker_init_fn,
             item_callback=self.dataloader_item_callback,
```
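Two things change here: `batch_size` now lives only on the top-level `DataloadersConfig` (the per-dataloader copies interpolated via `II` are gone), and the train dataloader's `num_workers` defaults through the `mlfab.num_workers` resolver registered in `xax/utils/experiments.py` below. For readers unfamiliar with `II`, it is OmegaConf shorthand for a `${...}` interpolation string; here is a toy sketch of the mechanics, with invented field names:

```python
from dataclasses import dataclass, field

from omegaconf import II, OmegaConf


@dataclass
class Inner:
    # II("workers") is just the string "${workers}", resolved lazily.
    num_workers: int = II("workers")


@dataclass
class Outer:
    workers: int = 4
    inner: Inner = field(default_factory=Inner)


cfg = OmegaConf.structured(Outer)
print(cfg.inner.num_workers)  # 4, pulled from the enclosing config
```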
xax/task/mixins/process.py (4 additions & 0 deletions)

```diff
@@ -38,6 +38,10 @@ def __init__(self, config: Config) -> None:
     def multiprocessing_context(self) -> BaseContext:
         return self._mp_ctx

+    @property
+    def multiprocessing_manager(self) -> SyncManager:
+        return self._mp_manager
+
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
```
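The `Dataloader` now receives a `SyncManager` instead of a raw multiprocessing context, and this property exposes it. A rough sketch of how the two objects relate (variable names here are illustrative, not the mixin's internals):

```python
import multiprocessing as mp
from multiprocessing.managers import SyncManager

if __name__ == "__main__":  # guard required on spawn-based platforms
    ctx = mp.get_context("spawn")         # what multiprocessing_context returns
    manager: SyncManager = ctx.Manager()  # what multiprocessing_manager returns

    # Manager-backed objects (queues, dicts, ...) can be shared with workers.
    queue = manager.Queue()
    queue.put("sample")
    print(queue.get())
    manager.shutdown()
```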
xax/task/mixins/train.py (1 addition & 6 deletions)

```diff
@@ -168,7 +168,6 @@ class TrainMixin(
     _training_over_flag: bool
     _last_printed_remaining_time: float
     _step_kind: StepKind
-    _prng_key: jnp.ndarray

     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -192,12 +191,8 @@ def __init__(self, config: Config) -> None:
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)

-        # Defines a PRNG key for the task.
-        self._prng_key = jax.random.PRNGKey(self.config.random_seed)
-
-    @property
     def prng_key(self) -> jnp.ndarray:
-        return self._prng_key
+        return jax.random.PRNGKey(self.config.random_seed)

     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
```
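Dropping the cached `_prng_key` makes `prng_key` a plain method that rebuilds the key from the seed, which is why the MNIST example above now calls `self.prng_key()`. Repeated calls return the identical key; distinct randomness still comes from splitting, per the usual JAX idiom (a generic illustration, not repo code):

```python
import jax

key = jax.random.PRNGKey(1337)  # what prng_key() now computes on each call
assert (key == jax.random.PRNGKey(1337)).all()  # same seed, same key

# Callers needing fresh randomness should split rather than reuse the key.
k1, k2 = jax.random.split(key)
x = jax.random.normal(k1, (3,))
y = jax.random.normal(k2, (3,))
```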
xax/utils/experiments.py (44 additions & 1 deletion)

```diff
@@ -8,6 +8,7 @@
 import inspect
 import itertools
 import logging
+import math
 import os
 import random
 import re
@@ -30,7 +31,7 @@
 from jaxtyping import Array
 from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf

-from xax.core.conf import get_data_dir, get_pretrained_models_dir
+from xax.core.conf import get_data_dir, get_pretrained_models_dir, load_user_config
 from xax.core.state import State
 from xax.utils.text import colored

@@ -756,3 +757,45 @@ def get_state_dict_prefix(
     if regexp is not None:
         ckpt = {k: v for k, v in ckpt.items() if regexp.match(k)}
     return ckpt
+
+
+def split_n_items_across_workers(n: int, worker_id: int, num_workers: int) -> tuple[int, int]:
+    """Computes offsets for splitting N items across K workers.
+
+    This returns the start and end indices for the items to be processed by
+    the given worker. The end index is exclusive.
+
+    Args:
+        n: The number of items to process.
+        worker_id: The ID of the current worker.
+        num_workers: The total number of workers.
+
+    Returns:
+        The start and end index for the items in the current worker.
+    """
+    assert n >= num_workers, f"n ({n}) must be >= num_workers ({num_workers})"
+    assert 0 <= worker_id < num_workers, f"worker_id ({worker_id}) must be >= 0 and < num_workers ({num_workers})"
+
+    # The number of items to process per worker.
+    items_per_worker = math.ceil(n / num_workers)
+
+    # The start and end indices for the items to process.
+    start = worker_id * items_per_worker
+    end = min(start + items_per_worker, n)
+
+    return start, end
```
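For instance, 10 items across 3 workers yields ceil(10 / 3) = 4 items per worker, with the last worker taking the remainder (a usage sketch, assuming the function is imported from `xax.utils.experiments`):

```python
from xax.utils.experiments import split_n_items_across_workers

for worker_id in range(3):
    print(worker_id, split_n_items_across_workers(10, worker_id, 3))
# 0 (0, 4)
# 1 (4, 8)
# 2 (8, 10)
```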


The same hunk continues with the `num_workers` helper and its resolver registration:

```diff
+def num_workers(default: int) -> int:
+    max_workers = load_user_config().experiment.max_workers
+    if hasattr(os, "sched_getaffinity"):
+        try:
+            return min(len(os.sched_getaffinity(0)), max_workers)
+        except Exception:
+            pass
+    if (cpu_count := os.cpu_count()) is not None:
+        return min(cpu_count, max_workers)
+    return min(default, max_workers)
+
+
+OmegaConf.register_new_resolver("mlfab.num_workers", num_workers, replace=True)
```
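Once this module is imported, the resolver backs interpolations like the one used in `data_loader.py` above; a hedged end-to-end sketch:

```python
import xax.utils.experiments  # noqa: F401  # registers "mlfab.num_workers"
from omegaconf import OmegaConf

cfg = OmegaConf.create({"num_workers": "${mlfab.num_workers:-1}"})
# Resolves to min(available CPUs, experiment.max_workers), falling back
# to the default argument (-1 here) only when no CPU count is available.
print(cfg.num_workers)
```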
xax/utils/logging.py (29 additions & 0 deletions)

```diff
@@ -5,6 +5,8 @@
 import socket
 import sys

+from omegaconf import OmegaConf
+
 from xax.core.conf import load_user_config
 from xax.utils.text import Color, color_parts, colored

@@ -175,6 +177,33 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
     logging.getLogger("torch").setLevel(logging.WARNING)


+def get_unused_port(default: int | None = None) -> int:
+    """Returns an unused port number on the local machine.
+
+    Args:
+        default: A default port to try before trying other ports.
+
+    Returns:
+        A port number which is currently unused.
+    """
+    if default is not None:
+        sock = socket.socket()
+        try:
+            sock.bind(("", default))
+            return default
+        except OSError:
+            pass
+        finally:
+            sock.close()
+
+    sock = socket.socket()
+    sock.bind(("", 0))
+    return sock.getsockname()[1]
+
+
+OmegaConf.register_new_resolver("mlfab.unused_port", get_unused_port, replace=True)
+
+
 def port_is_busy(port: int) -> int:
     """Checks whether a port is busy.
```
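A quick usage sketch for the new port helper (assuming `xax.utils.logging` is importable):

```python
from xax.utils.logging import get_unused_port

# Try a preferred port first; fall back to an OS-assigned free port.
port = get_unused_port(default=29500)
print(f"master port: {port}")
```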
