Commit bee9a9d

Add environment variables in job definitions

srisco committed Jun 26, 2019
1 parent a71208f

Showing 2 changed files with 70 additions and 23 deletions.
73 changes: 54 additions & 19 deletions scar/providers/aws/batchfunction.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 from scar.providers.aws import GenericClient
 import scar.logger as logger
 from scar.providers.aws.launchtemplates import LaunchTemplates
-from scar.utils import DataTypesUtils
+from scar.utils import DataTypesUtils, FileUtils, StrUtils
 
 
 class Batch(GenericClient):
@@ -28,13 +29,16 @@ def launch_templates(self):
     def __init__(self, aws_properties, supervisor_version):
         super().__init__(aws_properties)
         self.supervisor_version = supervisor_version
+        self.script = self._get_user_script()
         self._initialize_properties()
 
     def _initialize_properties(self):
         self.aws.batch.instance_role = "arn:aws:iam::{0}:instance-profile/ecsInstanceRole".format(
             self.aws.account_id)
         self.aws.batch.service_role = "arn:aws:iam::{0}:role/service-role/AWSBatchServiceRole".format(
             self.aws.account_id)
+        self.aws.batch.env_vars = []
+        self._set_required_environment_variables()
 
     def exist_compute_environments(self, name):
         creation_args = self.get_describe_compute_env_args(name_c=name)
@@ -46,6 +50,13 @@ def delete_compute_environment(self, name):
         self._delete_job_queue(name)
         self._delete_compute_env(name)
 
+    def _get_user_script(self):
+        script = ''
+        if self.aws._lambda.init_script:
+            file_content = FileUtils.read_file(self.aws._lambda.init_script)
+            script = StrUtils.utf8_to_base64_string(file_content)
+        return script
+
     def _get_job_definitions(self, jobs_info):
         return ["{0}:{1}".format(definition['jobDefinitionName'], definition['revision']) for definition in jobs_info['jobDefinitions']]
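The new _get_user_script helper reads the user's init script and base64-encodes it, so the whole script can travel to the Batch container in the single SCRIPT environment variable set by _set_required_environment_variables further down. A minimal sketch of the round trip, assuming StrUtils.utf8_to_base64_string wraps the standard base64 module (an assumption about the helper in scar.utils, not its actual code):

    import base64

    def utf8_to_base64_string(content: str) -> str:
        # Base64 keeps newlines and quotes intact when the whole script
        # is passed as a single environment-variable value.
        return base64.b64encode(content.encode('utf-8')).decode('utf-8')

    def base64_to_utf8_string(encoded: str) -> str:
        # The inverse step a consumer inside the container would apply.
        return base64.b64decode(encoded).decode('utf-8')

    script = '#!/bin/bash\necho "init script"\n'
    assert base64_to_utf8_string(utf8_to_base64_string(script)) == script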

@@ -179,7 +190,7 @@ def _get_job_definition_args(self):
                     'name': 'supervisor-bin'
                 }
             ],
-            'environment': self._get_job_env_vars(),
+            'environment': self.aws.batch.env_vars,
             'mountPoints': [
                 {
                     'containerPath': '/opt/faas-supervisor/bin',
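The environment key follows the AWS Batch containerProperties schema: a list of {'name': ..., 'value': ...} pairs, the same shape that _set_batch_environment_variable appends in the next hunk. As a sketch, a job definition with that shape could be registered directly through boto3; the job name, image and resources below are placeholders, not values SCAR uses:

    import boto3

    batch = boto3.client('batch')

    # Illustrative job definition; name, image and resources are placeholders.
    response = batch.register_job_definition(
        jobDefinitionName='scar-example-job',
        type='container',
        containerProperties={
            'image': 'ubuntu:18.04',
            'vcpus': 1,
            'memory': 1024,
            # The same {'name': ..., 'value': ...} shape built up
            # by _set_batch_environment_variable.
            'environment': [
                {'name': 'AWS_LAMBDA_FUNCTION_NAME', 'value': 'my-function'},
            ],
        },
    )
    print(response['jobDefinitionArn'])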
@@ -197,32 +208,56 @@ def _get_job_definition_args(self):
             ]
         return job_def_args
 
-    def _parse_vars(self, env_vars):
-        env_vars = []
+    def _add_custom_environment_variables(self, env_vars):
         if isinstance(env_vars, dict):
             for key, val in env_vars.items():
-                env_vars.append({
-                    'name': key,
-                    'value': val
-                })
+                self._set_batch_environment_variable(key, val)
         else:
             for env_var in env_vars:
                 key_val = env_var.split("=")
-                env_vars.append({
-                    'name': key_val[0],
-                    'value': key_val[1]
-                })
-        return env_vars
-
-    def _get_job_env_vars(self):
-        env_vars = []
+                self._set_batch_environment_variable(key_val[0], key_val[1])
+
+    def _set_batch_environment_variable(self, key, value):
+        self.aws.batch.env_vars.append({
+            'name': key,
+            'value': value
+        })
+
+    def _add_s3_environment_vars(self):
+        if hasattr(self.aws, "s3"):
+            provider_id = random.randint(1, 1000001)
+
+            if hasattr(self.aws.s3, "input_bucket"):
+                self._set_batch_environment_variable(
+                    f'STORAGE_PATH_INPUT_{provider_id}',
+                    self.aws.s3.storage_path_input
+                )
+            if hasattr(self.aws.s3, "output_bucket"):
+                self._set_batch_environment_variable(
+                    f'STORAGE_PATH_OUTPUT_{provider_id}',
+                    self.aws.s3.storage_path_output
+                )
+            else:
+                self._set_batch_environment_variable(
+                    f'STORAGE_PATH_OUTPUT_{provider_id}',
+                    self.aws.s3.storage_path_input
+                )
+            self._set_batch_environment_variable(
+                f'STORAGE_AUTH_S3_USER_{provider_id}',
+                'scar'
+            )
+
+    def _set_required_environment_variables(self):
+        self._set_batch_environment_variable('AWS_LAMBDA_FUNCTION_NAME', self.aws._lambda.name)
+        if self.script:
+            self._set_batch_environment_variable('SCRIPT', self.script)
         if (hasattr(self.aws._lambda, 'environment_variables') and
                 self.aws._lambda.environment_variables):
-            env_vars.extend(self._parse_vars(self.aws._lambda.environment_variables))
+            self._add_custom_environment_variables(self.aws._lambda.environment_variables)
         if (hasattr(self.aws._lambda, 'lambda_environment') and
                 self.aws._lambda.lambda_environment):
-            env_vars.extend(self._parse_vars(self.aws._lambda.lambda_environment))
+            self._add_custom_environment_variables(self.aws._lambda.lambda_environment)
+        self._add_s3_environment_vars()
 
     def get_state_and_status_of_compute_env(self, name=None):
         creation_args = self.get_describe_compute_env_args(name_c=name)
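Taken together, for a function named my-function with an init script and only an S3 input bucket configured, self.aws.batch.env_vars would end up roughly as follows. The numeric suffix comes from random.randint(1, 1000001) and changes per deployment; every value here is illustrative:

    env_vars = [
        {'name': 'AWS_LAMBDA_FUNCTION_NAME', 'value': 'my-function'},
        # Base64-encoded init script produced by _get_user_script.
        {'name': 'SCRIPT', 'value': 'IyEvYmluL2Jhc2gK...'},
        {'name': 'STORAGE_PATH_INPUT_731842', 'value': 'my-bucket/input'},
        # With no output bucket, the output path falls back to the
        # input path, mirroring the else branch above.
        {'name': 'STORAGE_PATH_OUTPUT_731842', 'value': 'my-bucket/input'},
        {'name': 'STORAGE_AUTH_S3_USER_731842', 'value': 'scar'},
    ]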
20 changes: 16 additions & 4 deletions scar/providers/aws/lambdafunction.py
@@ -155,13 +155,25 @@ def _add_s3_environment_vars(self):
             provider_id = random.randint(1, 1000001)
 
             if hasattr(self.aws.s3, "input_bucket"):
-                self._add_lambda_environment_variable('STORAGE_PATH_INPUT_{}'.format(provider_id), self.aws.s3.storage_path_input)
+                self._add_lambda_environment_variable(
+                    f'STORAGE_PATH_INPUT_{provider_id}',
+                    self.aws.s3.storage_path_input
+                )
 
             if hasattr(self.aws.s3, "output_bucket"):
-                self._add_lambda_environment_variable('STORAGE_PATH_OUTPUT_{}'.format(provider_id), self.aws.s3.storage_path_output)
+                self._add_lambda_environment_variable(
+                    f'STORAGE_PATH_OUTPUT_{provider_id}',
+                    self.aws.s3.storage_path_output
+                )
             else:
-                self._add_lambda_environment_variable('STORAGE_PATH_OUTPUT_{}'.format(provider_id), self.aws.s3.storage_path_input)
-                self._add_lambda_environment_variable('STORAGE_AUTH_S3_{}_USER'.format(provider_id), "scar")
+                self._add_lambda_environment_variable(
+                    f'STORAGE_PATH_OUTPUT_{provider_id}',
+                    self.aws.s3.storage_path_input
+                )
+            self._add_lambda_environment_variable(
+                f'STORAGE_AUTH_S3_USER_{provider_id}',
+                'scar'
+            )
 
     @excp.exception(logger)
     def _set_function_code(self):
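With this change the Lambda provider uses the same STORAGE_* naming as the Batch provider above, including the STORAGE_AUTH_S3_USER_<id> form of the auth variable. A hypothetical sketch of how code running inside the container could group these suffixed variables back together by provider id (read_storage_providers is an illustration, not faas-supervisor's actual parsing code):

    import os
    import re

    def read_storage_providers(environ=os.environ):
        # Group STORAGE_* variables by the random numeric suffix that ties
        # one provider's input path, output path and user together.
        pattern = re.compile(r'^STORAGE_(PATH_INPUT|PATH_OUTPUT|AUTH_S3_USER)_(\d+)$')
        providers = {}
        for key, value in environ.items():
            match = pattern.match(key)
            if match:
                field, provider_id = match.groups()
                providers.setdefault(provider_id, {})[field] = value
        return providers

    # e.g. {'731842': {'PATH_INPUT': 'my-bucket/input', ...}}
    print(read_storage_providers())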