Skip to content

Commit c3e071a

Browse files
committed
assume role boto clients
1 parent 5399278 commit c3e071a

21 files changed

+153
-165
lines changed

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/item_reader/resource_eval/resource_eval_s3.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.item_reader.resource_eval.resource_eval import (
66
ResourceEval,
77
)
8+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
9+
StateCredentials,
10+
)
811
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
912
ResourceRuntimePart,
1013
)
@@ -15,31 +18,41 @@
1518

1619
class ResourceEvalS3(ResourceEval):
1720
_HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_"
18-
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None]
21+
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart, StateCredentials], None]
1922

2023
@staticmethod
21-
def _get_s3_client(resource_runtime_part: ResourceRuntimePart):
24+
def _get_s3_client(
25+
resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials
26+
):
2227
return boto_client_for(
23-
region=resource_runtime_part.region,
24-
account=resource_runtime_part.account,
25-
service="s3",
28+
region=resource_runtime_part.region, service="s3", state_credentials=state_credentials
2629
)
2730

2831
@staticmethod
29-
def _handle_get_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None:
30-
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
32+
def _handle_get_object(
33+
env: Environment,
34+
resource_runtime_part: ResourceRuntimePart,
35+
state_credentials: StateCredentials,
36+
) -> None:
37+
s3_client = ResourceEvalS3._get_s3_client(
38+
resource_runtime_part=resource_runtime_part, state_credentials=state_credentials
39+
)
3140
parameters = env.stack.pop()
32-
response = s3_client.get_object(**parameters)
41+
response = s3_client.get_object(**parameters) # noqa
3342
content = to_str(response["Body"].read())
3443
env.stack.append(content)
3544

3645
@staticmethod
3746
def _handle_list_objects_v2(
38-
env: Environment, resource_runtime_part: ResourceRuntimePart
47+
env: Environment,
48+
resource_runtime_part: ResourceRuntimePart,
49+
state_credentials: StateCredentials,
3950
) -> None:
40-
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
51+
s3_client = ResourceEvalS3._get_s3_client(
52+
resource_runtime_part=resource_runtime_part, state_credentials=state_credentials
53+
)
4154
parameters = env.stack.pop()
42-
response = s3_client.list_objects_v2(**parameters)
55+
response = s3_client.list_objects_v2(**parameters) # noqa
4356
contents = response["Contents"]
4457
env.stack.append(contents)
4558

@@ -55,4 +68,5 @@ def eval_resource(self, env: Environment) -> None:
5568
self.resource.eval(env=env)
5669
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
5770
resolver_handler = self._get_api_action_handler()
58-
resolver_handler(env, resource_runtime_part)
71+
state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn)
72+
resolver_handler(env, resource_runtime_part, state_credentials)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/result_writer/resource_eval/resource_eval_s3.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.result_writer.resource_eval.resource_eval import (
77
ResourceEval,
88
)
9+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
10+
StateCredentials,
11+
)
912
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
1013
ResourceRuntimePart,
1114
)
@@ -16,22 +19,28 @@
1619

1720
class ResourceEvalS3(ResourceEval):
1821
_HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_"
19-
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None]
22+
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart, StateCredentials], None]
2023

2124
@staticmethod
22-
def _get_s3_client(resource_runtime_part: ResourceRuntimePart):
25+
def _get_s3_client(
26+
resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials
27+
):
2328
return boto_client_for(
24-
region=resource_runtime_part.region,
25-
account=resource_runtime_part.account,
26-
service="s3",
29+
service="s3", region=resource_runtime_part.region, state_credentials=state_credentials
2730
)
2831

2932
@staticmethod
30-
def _handle_put_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None:
33+
def _handle_put_object(
34+
env: Environment,
35+
resource_runtime_part: ResourceRuntimePart,
36+
state_credentials: StateCredentials,
37+
) -> None:
3138
parameters = env.stack.pop()
3239
env.stack.pop() # TODO: results
3340

34-
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
41+
s3_client = ResourceEvalS3._get_s3_client(
42+
resource_runtime_part=resource_runtime_part, state_credentials=state_credentials
43+
)
3544
map_run_record = env.map_run_record_pool_manager.get_all().pop()
3645
map_run_uuid = map_run_record.map_run_arn.split(":")[-1]
3746
if parameters["Prefix"] != "" and not parameters["Prefix"].endswith("/"):
@@ -66,4 +75,5 @@ def eval_resource(self, env: Environment) -> None:
6675
self.resource.eval(env=env)
6776
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
6877
resolver_handler = self._get_api_action_handler()
69-
resolver_handler(env, resource_runtime_part)
78+
state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn)
79+
resolver_handler(env, resource_runtime_part, state_credentials)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/credentials.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from typing import Final, Optional
1+
from dataclasses import dataclass
2+
from typing import Final
23

34
from localstack.services.stepfunctions.asl.component.common.string.string_expression import (
45
StringExpression,
56
)
67
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
78
from localstack.services.stepfunctions.asl.eval.environment import Environment
89

9-
_CREDENTIALS_ROLE_ARN_KEY: Final[str] = "RoleArn"
10-
ComputedCredentials = dict
10+
11+
@dataclass
12+
class StateCredentials:
13+
role_arn: str
1114

1215

1316
class RoleArn(EvalComponent):
@@ -26,12 +29,8 @@ class Credentials(EvalComponent):
2629
def __init__(self, role_arn: RoleArn):
2730
self.role_arn = role_arn
2831

29-
@staticmethod
30-
def get_role_arn_from(computed_credentials: ComputedCredentials) -> Optional[str]:
31-
return computed_credentials.get(_CREDENTIALS_ROLE_ARN_KEY)
32-
3332
def _eval_body(self, env: Environment) -> None:
3433
self.role_arn.eval(env=env)
3534
role_arn = env.stack.pop()
36-
computes_credentials: ComputedCredentials = {_CREDENTIALS_ROLE_ARN_KEY: role_arn}
35+
computes_credentials = StateCredentials(role_arn=role_arn)
3736
env.stack.append(computes_credentials)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from localstack.aws.api.lambda_ import InvocationResponse
66
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
7-
ComputedCredentials,
7+
StateCredentials,
88
)
99
from localstack.services.stepfunctions.asl.eval.environment import Environment
1010
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
@@ -39,10 +39,10 @@ def _from_payload(payload_streaming_body: IO[bytes]) -> Union[json, str]:
3939

4040

4141
def exec_lambda_function(
42-
env: Environment, parameters: dict, region: str, account: str, credentials: ComputedCredentials
42+
env: Environment, parameters: dict, region: str, state_credentials: StateCredentials
4343
) -> None:
4444
lambda_client = boto_client_for(
45-
region=region, account=account, service="lambda", credentials=credentials
45+
service="lambda", region=region, state_credentials=state_credentials
4646
)
4747

4848
invocation_resp: InvocationResponse = lambda_client.invoke(**parameters)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
StatesErrorNameType,
3131
)
3232
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
33-
ComputedCredentials,
34-
Credentials,
33+
StateCredentials,
3534
)
3635
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
3736
ResourceRuntimePart,
@@ -235,15 +234,15 @@ def _eval_service_task(
235234
env: Environment,
236235
resource_runtime_part: ResourceRuntimePart,
237236
normalised_parameters: dict,
238-
task_credentials: ComputedCredentials,
237+
state_credentials: StateCredentials,
239238
): ...
240239

241240
def _before_eval_execution(
242241
self,
243242
env: Environment,
244243
resource_runtime_part: ResourceRuntimePart,
245244
raw_parameters: dict,
246-
task_credentials: TaskCredentials,
245+
state_credentials: StateCredentials,
247246
) -> None:
248247
parameters_str = to_json_str(raw_parameters)
249248

@@ -263,7 +262,7 @@ def _before_eval_execution(
263262
scheduled_event_details["heartbeatInSeconds"] = heartbeat_seconds
264263
if self.credentials:
265264
scheduled_event_details["taskCredentials"] = TaskCredentials(
266-
roleArn=Credentials.get_role_arn_from(computed_credentials=task_credentials)
265+
roleArn=state_credentials.role_arn
267266
)
268267
env.event_manager.add_event(
269268
context=env.event_history_context,
@@ -286,7 +285,7 @@ def _after_eval_execution(
286285
env: Environment,
287286
resource_runtime_part: ResourceRuntimePart,
288287
normalised_parameters: dict,
289-
task_credentials: ComputedCredentials,
288+
state_credentials: StateCredentials,
290289
) -> None:
291290
output = env.stack[-1]
292291
self._verify_size_quota(env=env, value=output)
@@ -308,13 +307,13 @@ def _eval_execution(self, env: Environment) -> None:
308307
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
309308

310309
raw_parameters = self._eval_parameters(env=env)
311-
task_credentials = self._eval_credentials(env=env)
310+
state_credentials = self._eval_state_credentials(env=env)
312311

313312
self._before_eval_execution(
314313
env=env,
315314
resource_runtime_part=resource_runtime_part,
316315
raw_parameters=raw_parameters,
317-
task_credentials=task_credentials,
316+
state_credentials=state_credentials,
318317
)
319318

320319
normalised_parameters = copy.deepcopy(raw_parameters)
@@ -324,7 +323,7 @@ def _eval_execution(self, env: Environment) -> None:
324323
env=env,
325324
resource_runtime_part=resource_runtime_part,
326325
normalised_parameters=normalised_parameters,
327-
task_credentials=task_credentials,
326+
state_credentials=state_credentials,
328327
)
329328

330329
output_value = env.stack[-1]
@@ -334,5 +333,5 @@ def _eval_execution(self, env: Environment) -> None:
334333
env=env,
335334
resource_runtime_part=resource_runtime_part,
336335
normalised_parameters=normalised_parameters,
337-
task_credentials=task_credentials,
336+
state_credentials=state_credentials,
338337
)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
FailureEvent,
2525
)
2626
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
27-
ComputedCredentials,
27+
StateCredentials,
2828
)
2929
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
3030
ResourceCondition,
@@ -294,7 +294,7 @@ def _eval_service_task(
294294
env: Environment,
295295
resource_runtime_part: ResourceRuntimePart,
296296
normalised_parameters: dict,
297-
task_credentials: ComputedCredentials,
297+
state_credentials: StateCredentials,
298298
):
299299
# TODO: add support for task credentials
300300

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
StatesErrorNameType,
1616
)
1717
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
18-
ComputedCredentials,
18+
StateCredentials,
1919
)
2020
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
2121
ResourceCondition,
@@ -128,15 +128,14 @@ def _eval_service_task(
128128
env: Environment,
129129
resource_runtime_part: ResourceRuntimePart,
130130
normalised_parameters: dict,
131-
task_credentials: ComputedCredentials,
131+
state_credentials: StateCredentials,
132132
):
133133
service_name = self._get_boto_service_name()
134134
api_action = self._get_boto_service_action()
135135
api_client = boto_client_for(
136-
region=resource_runtime_part.region,
137-
account=resource_runtime_part.account,
138136
service=service_name,
139-
credentials=task_credentials,
137+
region=resource_runtime_part.region,
138+
state_credentials=state_credentials,
140139
)
141140
response = getattr(api_client, api_action)(**normalised_parameters) or dict()
142141
if response:

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_batch.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
StatesErrorNameType,
1919
)
2020
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
21-
ComputedCredentials,
21+
StateCredentials,
2222
)
2323
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
2424
ResourceCondition,
@@ -90,15 +90,15 @@ def _before_eval_execution(
9090
env: Environment,
9191
resource_runtime_part: ResourceRuntimePart,
9292
raw_parameters: dict,
93-
task_credentials: ComputedCredentials,
93+
state_credentials: StateCredentials,
9494
) -> None:
9595
if self.resource.condition == ResourceCondition.Sync:
9696
self._attach_aws_environment_variables(parameters=raw_parameters)
9797
super()._before_eval_execution(
9898
env=env,
9999
resource_runtime_part=resource_runtime_part,
100100
raw_parameters=raw_parameters,
101-
task_credentials=task_credentials,
101+
state_credentials=state_credentials,
102102
)
103103

104104
def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
@@ -138,12 +138,12 @@ def _build_sync_resolver(
138138
env: Environment,
139139
resource_runtime_part: ResourceRuntimePart,
140140
normalised_parameters: dict,
141-
task_credentials: ComputedCredentials,
141+
state_credentials: StateCredentials,
142142
) -> Callable[[], Optional[Any]]:
143143
batch_client = boto_client_for(
144-
region=resource_runtime_part.region,
145-
account=resource_runtime_part.account,
146144
service="batch",
145+
region=resource_runtime_part.region,
146+
state_credentials=state_credentials,
147147
)
148148
submission_output: dict = env.stack.pop()
149149
job_id = submission_output["JobId"]
@@ -186,15 +186,14 @@ def _eval_service_task(
186186
env: Environment,
187187
resource_runtime_part: ResourceRuntimePart,
188188
normalised_parameters: dict,
189-
task_credentials: ComputedCredentials,
189+
state_credentials: StateCredentials,
190190
):
191191
service_name = self._get_boto_service_name()
192192
api_action = self._get_boto_service_action()
193193
batch_client = boto_client_for(
194194
region=resource_runtime_part.region,
195-
account=resource_runtime_part.account,
196195
service=service_name,
197-
credentials=task_credentials,
196+
state_credentials=state_credentials,
198197
)
199198
response = getattr(batch_client, api_action)(**normalised_parameters)
200199
response.pop("ResponseMetadata", None)

0 commit comments

Comments
 (0)