RL algorithm implementation #19

Merged (16 commits, Oct 4, 2024)
Changes from 1 commit
Added shorter_pulses to algorithm_kwargs and changed the callback to stop training early; removed update_solver(); use of var_time = True supported; other minor fixes
LegionAtol committed Aug 24, 2024
commit e66da0ffd63d65fb9fdd4b0d56071c47440c0d7c
77 changes: 34 additions & 43 deletions src/qutip_qoc/_rl.py
@@ -47,8 +47,6 @@ def __init__(
super(_RL,self).__init__()

self._Hd_lst, self._Hc_lst = [], []
- if not isinstance(objectives, list):
- objectives = [objectives]
for objective in objectives:
# extract drift and control Hamiltonians from the objective
self._Hd_lst.append(objective.H[0])
@@ -61,15 +59,15 @@ def __init__(
self._H_lst.append([Hc, lambda t, args: self.pulse(t, self.args, i+1)])
self._H = qt.QobjEvo(self._H_lst, self.args)

- self._control_parameters = control_parameters
- # extract bounds for _control_parameters
+ # extract bounds for control_parameters
bounds = []
for key in control_parameters.keys():
bounds.append(control_parameters[key].get("bounds"))
self.lbound = [b[0][0] for b in bounds]
self.ubound = [b[0][1] for b in bounds]

self._alg_kwargs = alg_kwargs
+ self.shorter_pulses = self._alg_kwargs.get("shorter_pulses", False) # lengthen the training to look for shorter-duration pulses, i.e. episodes with fewer steps

self._initial = objectives[0].initial
self._target = objectives[0].target
@@ -82,7 +80,7 @@ def __init__(
start_local_time = time.localtime(), # initial optimization time
n_iters = 0, # Number of iterations(episodes) until convergence
iter_seconds = [], # list containing the time taken for each iteration(episode) of the optimization
- var_time = False, # Whether the optimization was performed with variable time
+ var_time = True, # Whether the optimization was performed with variable time
)

#for the reward
@@ -123,21 +121,14 @@ def __init__(
self.action_space = spaces.Box(low=-1, high=1, shape=(len(self._Hc_lst[0]),), dtype=np.float32) # Continuous action space from -1 to +1, as suggested from gym
self.observation_space = spaces.Box(low=-1, high=1, shape=obs_shape, dtype=np.float32) # Observation space

- def update_solver(self):
- """
- Update the solver and fidelity type based on the problem setup.
- Chooses the appropriate solver (Schrödinger or master equation) and
- prepares for infidelity calculation.
- """
- # create the solver
- if self._Hd_lst[0].issuper:
- self._fid_type = self._alg_kwargs.get("fid_type", "TRACEDIFF")
- self._solver = qt.MESolver(H=self._H, options=self._integrator_kwargs)
- else:
- self._fid_type = self._alg_kwargs.get("fid_type", "PSU")
- self._solver = qt.SESolver(H=self._H, options=self._integrator_kwargs)
-
- self.infidelity = self._infid
-
def pulse(self, t, args, idx):
"""
Returns the control pulse value at time t for a given index.
@@ -153,7 +144,8 @@ def save_episode_info(self):
"final_infidelity": self._result.infidelity,
"terminated": self.terminated,
"truncated": self.truncated,
"steps_used": self.current_step
"steps_used": self.current_step,
"elapsed_time": time.mktime(time.localtime())
}
self.episode_info.append(episode_data)

@@ -189,10 +181,8 @@ def step(self, action):

for i, value in enumerate(alphas):
self.args[f"alpha{i+1}"] = value
- self._H = qt.QobjEvo(self._H_lst, self.args)

- self.update_solver() # _H has changed
- infidelity = self.infidelity()
+ infidelity = self._infid()

self.current_step += 1
self.temp_actions.append(alphas)
@@ -220,7 +210,7 @@ def reset(self, seed=None):
"""
self.save_episode_info()

- time_diff = time.mktime(time.localtime()) - time.mktime(self._result.start_local_time)
+ time_diff = self.episode_info[-1]["elapsed_time"] - (self.episode_info[-2]["elapsed_time"] if len(self.episode_info) > 1 else time.mktime(self._result.start_local_time))
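+ # time_diff is the wall time of the episode just finished: the difference between consecutive "elapsed_time" stamps, falling back to start_local_time for the first episode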
self._result.iter_seconds.append(time_diff)
self.current_step = 0 # Reset the step counter
self.current_episode += 1 # Increment episode counter
@@ -238,7 +228,7 @@ def result(self):
"""
self._result.end_local_time = time.localtime()
self._result.n_iters = len(self._result.iter_seconds)
- self._result.optimized_params = self.actions.copy()
+ self._result.optimized_params = self.actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
self._result._optimized_controls = self.actions.copy()
self._result._final_states = (self._result._final_states if self._result._final_states is not None else []) + [self.state]
self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time) # Convert to a string
@@ -274,40 +264,41 @@ def __init__(self, verbose: int = 0):

def _on_step(self) -> bool:
"""
- This method is required by the BaseCallback class. We use it only to stop the training.
+ This method is required by the BaseCallback class. We use it to stop the training.
+ - Stop training if the maximum number of episodes is reached.
+ - Stop training if an episode with infidelity <= target infidelity is found.
"""
env = self.training_env.envs[0].unwrapped

# Check if we need to stop training
if self.stop_train:
return False # Stop training
+ elif env.current_episode >= env.max_episodes:
+ env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
+ return False # Stop training
+ elif (env._result.infidelity <= env._fid_err_targ) and not(env.shorter_pulses):
+ env._result.message = f"Stop training because an episode with infidelity <= target infidelity was found"
+ return False # Stop training
return True # Continue training

def _on_rollout_start(self) -> None:
"""
This method is called before the rollout starts (before collecting new samples).
Checks:
- If all of the last 100 episodes have infidelity below the target and use the same number of steps, stop training.
- Stop training if the maximum number of episodes is reached.
"""
#could be moved to on_step

env = self.training_env.envs[0].unwrapped
- max_episodes = env.max_episodes
- fid_err_targ = env._fid_err_targ
-
- if len(env.episode_info) >= 100:
- last_100_episodes = env.episode_info[-100:]
-
- min_steps = min(info['steps_used'] for info in last_100_episodes)
- steps_condition = all(ep['steps_used'] == min_steps for ep in last_100_episodes)
- infid_condition = all(ep['final_infidelity'] <= fid_err_targ for ep in last_100_episodes)
-
- if steps_condition and infid_condition:
- env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid."
- #print(f"Stopping training as no episode in the last 100 used fewer steps and infidelity was below target infid.")
- self.stop_train = True # Stop training
- #print([ep['steps_used'] for ep in last_100_episodes])
- #print([ep['final_infidelity'] for ep in last_100_episodes])
-
- # Check max episodes condition
- if env.current_episode >= max_episodes:
- env._result.message = f"Reached {max_episodes} episodes, stopping training."
- #print(f"Reached {max_episodes} episodes, stopping training.")
- self.stop_train = True # Stop training
+ # Only if specified in alg_kwargs, the algorithm will search for shorter pulses, resulting in episodes with fewer steps.
+ if env.shorter_pulses:
+ if len(env.episode_info) >= 100:
+ last_100_episodes = env.episode_info[-100:]
+
+ min_steps = min(info['steps_used'] for info in last_100_episodes)
+ steps_condition = all(ep['steps_used'] == min_steps for ep in last_100_episodes)
+ infid_condition = all(ep['final_infidelity'] <= env._fid_err_targ for ep in last_100_episodes)
+
+ if steps_condition and infid_condition:
+ env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid."
+ self.stop_train = True # Stop training
12 changes: 8 additions & 4 deletions src/qutip_qoc/pulse_optim.py
@@ -37,7 +37,6 @@ def optimize_pulses(
Dictionary of options for the control pulse optimization.
The keys of this dict must be a unique string identifier for each control Hamiltonian / function.
For the GOAT and JOPT algorithms, the dict may optionally also contain the key "__time__".
- For RL you don't need to specify the guess.
For each control function it must specify:

control_id : dict
@@ -46,6 +45,7 @@ def optimize_pulses(
where ``n`` is the number of independent variables.

- bounds : sequence, optional
+ For RL you don't need to specify the guess.
Sequence of ``(min, max)`` pairs for each element in
`guess`. None is used to specify no bound.
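
For illustration, a minimal sketch of this dict in the format the RL test case in this PR uses (the name "p" and the bound values are taken from tests/test_result.py below; RL reads only the bounds):

    control_parameters = {
        "p": {"bounds": [(-13, 13)]},  # one control, bounded; RL ignores any "guess"
    }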

@@ -84,6 +84,11 @@ def optimize_pulses(
Global steps default to 0 (no global optimization).
Can be overridden by specifying in minimizer_kwargs.

+ - shorter_pulses : bool, optional
[Review comment from a collaborator: This is basically what the __time__ control parameter is used for. Of course, RL doesn't really take a 'guess' (as with the other parameters), but we could use the same structure to only look for shorter pulses if the __time__ bounds are specified.]
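
As a hedged sketch of that suggestion (hypothetical; the code in this PR uses the shorter_pulses flag instead), the existing "__time__" entry could carry the time bounds:

    control_parameters = {
        "p": {"bounds": [(-13, 13)]},
        # hypothetical: reuse the "__time__" entry so RL only searches for
        # shorter pulses when explicit time bounds are given
        "__time__": {"bounds": [(0.0, 10.0)]},
    }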

+ If set to True, allows the algorithm to search for shorter control
+ pulses that can achieve the desired fidelity target using fewer steps.
+ By default it is False, and the algorithm only attempts to reach the target infidelity.
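
A minimal end-to-end sketch of calling the RL optimizer with this flag, assembled from the test case later in this diff (the Hamiltonian, states, and pulse form are illustrative assumptions, not fixed by this docstring):

    import numpy as np
    import qutip as qt
    from qutip_qoc import Objective, optimize_pulses

    initial = qt.basis(2, 0)                          # assumed state-to-state transfer
    target = qt.basis(2, 1)
    H = [qt.sigmaz(), [qt.sigmax(), lambda t, p: p]]  # assumed drift + one control

    result = optimize_pulses(
        objectives=[Objective(initial, H, target)],
        control_parameters={"p": {"bounds": [(-13, 13)]}},  # RL needs only bounds
        tlist=np.linspace(0, 10, 100),
        algorithm_kwargs={
            "alg": "RL",
            "fid_err_targ": 0.01,
            "max_iter": 300,
            "shorter_pulses": True,  # also search for pulses that use fewer steps
        },
    )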

Algorithm-specific keywords for GRAPE and CRAB can be found in
:func:`qutip_qtrl.pulseoptim.optimize_pulse`.

@@ -154,7 +159,7 @@ def optimize_pulses(
# extract guess and bounds for the control pulses
x0, bounds = [], []
for key in control_parameters.keys():
- x0.append(control_parameters[key].get("guess")) # TODO: for now only consider bounds
+ x0.append(control_parameters[key].get("guess"))
bounds.append(control_parameters[key].get("bounds"))
try: # GRAPE, CRAB format
lbound = [b[0][0] for b in bounds]
@@ -351,8 +356,7 @@ def optimize_pulses(

qtrl_optimizers.append(qtrl_optimizer)

- # TODO: we can deal with proper handling later
- if alg == "RL":
+ elif alg == "RL":
rl_env = _RL(
objectives,
control_parameters,
28 changes: 13 additions & 15 deletions tests/test_result.py
@@ -153,8 +153,6 @@ def sin_z_jax(t, r, **kwargs):
)

# ----------------------- RL --------------------
- # TODO: this is the input for optimiz_pulses() function
- # you can use this routine to test your implementation

# state to state transfer
initial = qt.basis(2, 0)
@@ -169,38 +167,38 @@ def sin_z_jax(t, r, **kwargs):

state2state_rl = Case(
objectives=[Objective(initial, H, target)],
- #control_parameters={"bounds": [-13, 13]}, # TODO: for now only consider bounds
control_parameters = {
"p": {"bounds": [(-13, 13)],}
},
tlist=np.linspace(0, 10, 100),
algorithm_kwargs={
"fid_err_targ": 0.01,
"alg": "RL",
"max_iter": 70000,
"max_iter": 300,
"shorter_pulses": True,
},
optimizer_kwargs={}
)

# no big difference for unitary evolution

- #initial = qt.qeye(2) # Identity
- #target = qt.gates.hadamard_transform()
+ initial = qt.qeye(2) # Identity
+ target = qt.gates.hadamard_transform()

- #unitary_rl = state2state_rl._replace(
- # objectives=[Objective(initial, H, target)],
- #)
+ unitary_rl = state2state_rl._replace(
+ objectives=[Objective(initial, H, target)],
+ )
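
A hedged way to run only the two RL cases below (assumes pytest and substring matching on the param ids defined in the fixture):

    pytest tests/test_result.py -k "RL"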


@pytest.fixture(
params=[
- #pytest.param(state2state_grape, id="State to state (GRAPE)"),
- #pytest.param(state2state_crab, id="State to state (CRAB)"),
- #pytest.param(state2state_param_crab, id="State to state (param. CRAB)"),
- #pytest.param(state2state_goat, id="State to state (GOAT)"),
- #pytest.param(state2state_jax, id="State to state (JAX)"),
+ pytest.param(state2state_grape, id="State to state (GRAPE)"),
+ pytest.param(state2state_crab, id="State to state (CRAB)"),
+ pytest.param(state2state_param_crab, id="State to state (param. CRAB)"),
+ pytest.param(state2state_goat, id="State to state (GOAT)"),
+ pytest.param(state2state_jax, id="State to state (JAX)"),
pytest.param(state2state_rl, id="State to state (RL)"),
- #pytest.param(unitary_rl, id="Unitary (RL)"),
+ pytest.param(unitary_rl, id="Unitary (RL)"),
]
)
def tst(request):