diff --git a/compiler_opt/es/policy_utils_test.py b/compiler_opt/es/policy_utils_test.py index 2603fe1a..240517ed 100644 --- a/compiler_opt/es/policy_utils_test.py +++ b/compiler_opt/es/policy_utils_test.py @@ -110,9 +110,8 @@ class VectorTest(absltest.TestCase): params = np.arange(expected_length_of_a_perturbation, dtype=np.float32) POLICY_NAME = 'test_policy_name' - # TODO(abenalaast): Issue #280 - def test_set_vectorized_parameters_for_policy(self): - # create a policy + def _save_inlining_policy( + self) -> tuple[str, actor_policy.ActorPolicy, policy_saver.PolicySaver]: problem_config = registry.get_configuration( implementation=inlining.InliningConfig) time_step_spec, action_spec = problem_config.get_signature_spec() @@ -143,6 +142,12 @@ def test_set_vectorized_parameters_for_policy(self): testing_path = self.create_tempdir() policy_save_path = os.path.join(testing_path, 'temp_output', 'policy') saver.save(policy_save_path) + return (policy_save_path, policy, saver) + + # TODO(abenalaast): Issue #280 + def test_set_vectorized_parameters_for_policy(self): + # create a policy + policy_save_path, policy, _ = self._save_inlining_policy() # set the values of the policy variables policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) @@ -177,36 +182,7 @@ def test_set_vectorized_parameters_for_policy(self): # TODO(abenalaast): Issue #280 def test_get_vectorized_parameters_from_policy(self): # create a policy - problem_config = registry.get_configuration( - implementation=inlining.InliningConfig) - time_step_spec, action_spec = problem_config.get_signature_spec() - quantile_file_dir = os.path.join(TEST_PATH_PREFIX, 'compiler_opt', 'rl', - 'inlining', 'vocab') - creator = inlining_config.get_observation_processing_layer_creator( - quantile_file_dir=quantile_file_dir, - with_sqrt=False, - with_z_score_normalization=False) - layers = tf.nest.map_structure(creator, time_step_spec.observation) - - actor_network = actor_distribution_network.ActorDistributionNetwork( - input_tensor_spec=time_step_spec.observation, - output_tensor_spec=action_spec, - preprocessing_layers=layers, - preprocessing_combiner=tf.keras.layers.Concatenate(), - fc_layer_params=(64, 64, 64, 64), - dropout_layer_params=None, - activation_fn=tf.keras.activations.relu) - - policy = actor_policy.ActorPolicy( - time_step_spec=time_step_spec, - action_spec=action_spec, - actor_network=actor_network) - - # save the policy - saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy}) - testing_path = self.create_tempdir() - policy_save_path = os.path.join(testing_path, 'temp_output', 'policy') - saver.save(policy_save_path) + policy_save_path, policy, _ = self._save_inlining_policy() # functionality verified in previous test policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) @@ -226,36 +202,7 @@ def test_get_vectorized_parameters_from_policy(self): # TODO(abenalaast): Issue #280 def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self): # create a policy - problem_config = registry.get_configuration( - implementation=inlining.InliningConfig) - time_step_spec, action_spec = problem_config.get_signature_spec() - quantile_file_dir = os.path.join(TEST_PATH_PREFIX, 'compiler_opt', 'rl', - 'inlining', 'vocab') - creator = inlining_config.get_observation_processing_layer_creator( - quantile_file_dir=quantile_file_dir, - with_sqrt=False, - with_z_score_normalization=False) - layers = tf.nest.map_structure(creator, time_step_spec.observation) - - actor_network = actor_distribution_network.ActorDistributionNetwork( - input_tensor_spec=time_step_spec.observation, - output_tensor_spec=action_spec, - preprocessing_layers=layers, - preprocessing_combiner=tf.keras.layers.Concatenate(), - fc_layer_params=(64, 64, 64, 64), - dropout_layer_params=None, - activation_fn=tf.keras.activations.relu) - - policy = actor_policy.ActorPolicy( - time_step_spec=time_step_spec, - action_spec=action_spec, - actor_network=actor_network) - - # save the policy - saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy}) - testing_path = self.create_tempdir() - policy_save_path = os.path.join(testing_path, 'temp_output', 'policy') - saver.save(policy_save_path) + policy_save_path, policy, saver = self._save_inlining_policy() # set the values of the variables policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)