Skip to content

Commit

Permalink
jitted more components in power module
Browse files Browse the repository at this point in the history
  • Loading branch information
kkarrancsu committed Jun 7, 2023
1 parent dd89920 commit 78e3931
Show file tree
Hide file tree
Showing 4 changed files with 2,387 additions and 1,273 deletions.
91 changes: 25 additions & 66 deletions mechafil_jax/power.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Union
import numbers
from functools import partial

import jax
import jax.numpy as jnp
Expand All @@ -12,32 +13,14 @@
NOTE:
- This module does not support tunable QAP mode, only basic.
"""

# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def scalar_or_vector_to_vector(
input_x: Union[jnp.ndarray, NDArray, float], expected_len: int, err_msg: str = None
) -> jnp.ndarray:
if isinstance(input_x, numbers.Number):
return jnp.ones(expected_len) * input_x
elif isinstance(input_x, np.ndarray):
return jnp.array(input_x)
else:
err_msg_out = (
"vector input does not match expected length!"
if err_msg is None
else err_msg
)
assert len(input_x) == expected_len, err_msg_out
return input_x


# --------------------------------------------------------------------------------------
# QA Multiplier functions
# --------------------------------------------------------------------------------------
# NOTE: I tried adding in jax.jit here, but it seemed to make it slower?
# we need a more systematic way to benchmark all of this
def compute_qa_factor(
fil_plus_rate: Union[jnp.array, NDArray, float],
fil_plus_rate: Union[jnp.array, NDArray],
fil_plus_m: float = 10.0,
duration_m: Callable = None,
duration: int = None,
Expand All @@ -51,55 +34,44 @@ def compute_qa_factor(
# --------------------------------------------------------------------------------------
# Onboardings
# --------------------------------------------------------------------------------------
def forecast_rb_daily_onboardings(
rb_onboard_power: Union[jnp.array, NDArray, float], forecast_lenght: int
) -> jnp.array:
rb_onboarded_power_vec = scalar_or_vector_to_vector(
rb_onboard_power,
forecast_lenght,
err_msg="If rb_onboard_power is provided as a vector, it must be the same length as the forecast_length",
)
return rb_onboarded_power_vec


@jax.jit
def forecast_qa_daily_onboardings(
rb_onboard_power: Union[jnp.array, NDArray, float],
fil_plus_rate: Union[jnp.array, NDArray, float],
forecast_lenght: int,
rb_onboard_power: Union[jnp.array, NDArray],
fil_plus_rate: Union[jnp.array, NDArray],
fil_plus_m: float = 10.0,
duration_m: Callable = None,
duration: int = None,
) -> jnp.array:
# If duration_m is not provided, qa_factor = 1.0 + 9.0 * fil_plus_rate
qa_factor = compute_qa_factor(fil_plus_rate, fil_plus_m, duration_m, duration)
qa_onboard_power = qa_factor * rb_onboard_power
qa_onboard_power_vec = scalar_or_vector_to_vector(
qa_onboard_power,
forecast_lenght,
err_msg="If qa_onboard_power is provided as a vector, it must be the same length as the forecast_length",
)
qa_onboard_power_vec = qa_factor * rb_onboard_power
return qa_onboard_power_vec

# --------------------------------------------------------------------------------------
# Renewals
# --------------------------------------------------------------------------------------
@jax.jit
def basic_scalar_renewed_power(day_sched_expire_power, renewal_rate):
return day_sched_expire_power * renewal_rate

# --------------------------------------------------------------------------------------
# Scheduled expirations
# --------------------------------------------------------------------------------------
@jax.jit
def have_known_se_sector_info(arggs):
known_scheduled_expire_vec, day_i = arggs
return known_scheduled_expire_vec[day_i]

@jax.jit
def dont_have_known_se_sector_info(arggs):
return 0.0

@jax.jit
def have_modeled_sector_expiration_info(arggs):
day_onboard_vec, day_renewed_vec, day_i, duration = arggs
return day_onboard_vec[day_i - duration] + day_renewed_vec[day_i - duration]

@jax.jit
def dont_have_modeled_sector_expiration_info(arggs):
return 0.0

Expand Down Expand Up @@ -136,46 +108,33 @@ def compute_se_and_rr(carry, x):

return (day_rb_renewed_power_vec, rb_known_sched_expire, day_rb_onboarded_power, renewal_rate_vec, day_i+1, duration), day_se_power

@partial(jax.jit, static_argnums=(7,8,))
def forecast_power_stats(
rb_power_zero: float,
qa_power_zero: float,
rb_onboard_power: Union[jnp.array, NDArray, float],
day_rb_onboarded_power: Union[jnp.array, NDArray, float],
rb_known_scheduled_expire_vec: Union[jnp.array, NDArray],
qa_known_scheduled_expire_vec: Union[jnp.array, NDArray],
renewal_rate: Union[jnp.array, NDArray, float],
fil_plus_rate: Union[jnp.array, NDArray, float],
renewal_rate_vec: Union[jnp.array, NDArray],
fil_plus_rate_vec: Union[jnp.array, NDArray],
duration: int,
forecast_length: int,
fil_plus_m: float = 10,
**kwargs # a noop in this port, but here for backwards compatibility
):
# force duration to be an integer
duration = int(duration)

renewal_rate_vec = scalar_or_vector_to_vector(
renewal_rate,
forecast_length,
err_msg="If renewal_rate is provided as a vector, it must be the same length as the forecast_length",
)

day_rb_onboarded_power = forecast_rb_daily_onboardings(
rb_onboard_power, forecast_length
)
fil_plus_m: float = 10.0,
):
total_rb_onboarded_power = day_rb_onboarded_power.cumsum()

day_qa_onboarded_power = forecast_qa_daily_onboardings(
rb_onboard_power,
fil_plus_rate,
forecast_length,
day_rb_onboarded_power,
fil_plus_rate_vec,
fil_plus_m,
duration_m=None,
duration=duration,
)
total_qa_onboarded_power = day_qa_onboarded_power.cumsum()

# compute SE & RR for both RBP & QAP
day_rb_renewed_power_vec = jnp.zeros(len(day_rb_onboarded_power))
day_qa_renewed_power_vec = jnp.zeros(len(day_qa_onboarded_power))
day_rb_renewed_power_vec = jnp.zeros(forecast_length)
day_qa_renewed_power_vec = jnp.zeros(forecast_length)

init_in = (day_rb_renewed_power_vec, rb_known_scheduled_expire_vec, day_rb_onboarded_power, renewal_rate_vec, 0, duration)
ret, day_rb_scheduled_expire_power = lax.scan(compute_se_and_rr, init_in, None, length=forecast_length)
Expand All @@ -192,15 +151,15 @@ def forecast_power_stats(
total_qa_renewed_power = day_qa_renewed_power.cumsum()

# Total RB power
rb_power_zero_vec = np.ones(forecast_length) * rb_power_zero
rb_power_zero_vec = jnp.ones(forecast_length) * rb_power_zero
rb_total_power = (
rb_power_zero_vec
+ total_rb_onboarded_power
- total_rb_scheduled_expire_power
+ total_rb_renewed_power
)
# Total QA power
qa_power_zero_vec = np.ones(forecast_length) * qa_power_zero
qa_power_zero_vec = jnp.ones(forecast_length) * qa_power_zero
qa_total_power = (
qa_power_zero_vec
+ total_qa_onboarded_power
Expand Down
Loading

0 comments on commit 78e3931

Please sign in to comment.