Skip to content

Commit

Permalink
Get files with extension (bazelbuild#979)
Browse files Browse the repository at this point in the history
* Get files helper function

* Remove extra computation

* Move extension to shared file

* Move functions to paths.bzl
  • Loading branch information
borkaehw authored Jan 31, 2020
1 parent cbf3093 commit fadf4ce
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 43 deletions.
12 changes: 12 additions & 0 deletions scala/private/paths.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
java_extension = ".java"

scala_extension = ".scala"

srcjar_extension = ".srcjar"

def get_files_with_extension(ctx, extension):
return [
f
for f in ctx.files.srcs
if f.basename.endswith(extension)
]
58 changes: 16 additions & 42 deletions scala/private/phases/phase_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ load(
"@io_bazel_rules_scala//scala/private:coverage_replacements_provider.bzl",
_coverage_replacements_provider = "coverage_replacements_provider",
)
load(
"@io_bazel_rules_scala//scala/private:paths.bzl",
_get_files_with_extension = "get_files_with_extension",
_java_extension = "java_extension",
_scala_extension = "scala_extension",
_srcjar_extension = "srcjar_extension",
)
load(
"@io_bazel_rules_scala//scala/private:rule_impls.bzl",
_compile_scala = "compile_scala",
_expand_location = "expand_location",
)
load(":resources.bzl", _resource_paths = "paths")

_java_extension = ".java"

_scala_extension = ".scala"

_srcjar_extension = ".srcjar"

_empty_coverage_struct = struct(
external = struct(
replacements = {},
Expand Down Expand Up @@ -197,29 +198,17 @@ def _compile_or_empty(
merged_provider = scala_compilation_provider,
)
else:
in_srcjars = [
f
for f in ctx.files.srcs
if f.basename.endswith(_srcjar_extension)
]
java_srcs = _get_files_with_extension(ctx, _java_extension)
scala_srcs = _get_files_with_extension(ctx, _scala_extension)
in_srcjars = _get_files_with_extension(ctx, _srcjar_extension)
all_srcjars = depset(in_srcjars, transitive = [srcjars])

java_srcs = [
f
for f in ctx.files.srcs
if f.basename.endswith(_java_extension)
]

# We are not able to verify whether dependencies are used when compiling java sources
# Thus we disable unused dependency checking when java sources are found
if len(java_srcs) != 0:
unused_dependency_checker_mode = "off"

sources = [
f
for f in ctx.files.srcs
if f.basename.endswith(_scala_extension)
] + java_srcs
sources = scala_srcs + java_srcs
_compile_scala(
ctx,
ctx.label,
Expand Down Expand Up @@ -258,7 +247,7 @@ def _compile_or_empty(
# so set ijar == jar
ijar = ctx.outputs.jar

source_jar = _pack_source_jar(ctx)
source_jar = _pack_source_jar(ctx, scala_srcs, in_srcjars)
scala_compilation_provider = _create_scala_compilation_provider(ctx, ijar, source_jar, deps_providers)

# compile the java now
Expand Down Expand Up @@ -339,31 +328,16 @@ def _create_scala_compilation_provider(ctx, ijar, source_jar, deps_providers):
runtime_deps = runtime_deps,
)

def _pack_source_jar(ctx):
# collect .scala sources and pack a source jar for Scala
scala_sources = [
f
for f in ctx.files.srcs
if f.basename.endswith(_scala_extension)
]

# collect .srcjar files and pack them with the scala sources
bundled_source_jars = [
f
for f in ctx.files.srcs
if f.basename.endswith(_srcjar_extension)
]
scala_source_jar = java_common.pack_sources(
def _pack_source_jar(ctx, scala_srcs, in_srcjars):
return java_common.pack_sources(
ctx.actions,
output_jar = ctx.outputs.jar,
sources = scala_sources,
source_jars = bundled_source_jars,
sources = scala_srcs,
source_jars = in_srcjars,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
host_javabase = find_java_runtime_toolchain(ctx, ctx.attr._host_javabase),
)

return scala_source_jar

def _jacoco_offline_instrument(ctx, input_jar):
if not ctx.configuration.coverage_enabled or not hasattr(ctx.attr, "_code_coverage_instrumentation_worker"):
return _empty_coverage_struct
Expand Down
7 changes: 6 additions & 1 deletion scala/private/phases/phase_scalafmt.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
#
# Outputs to format the scala files when it is explicitly specified
#
load(
"@io_bazel_rules_scala//scala/private:paths.bzl",
_scala_extension = "scala_extension",
)

def phase_scalafmt(ctx, p):
if ctx.attr.format:
manifest, files = _build_format(ctx)
Expand All @@ -17,7 +22,7 @@ def _build_format(ctx):
manifest_content = []
for src in ctx.files.srcs:
# only format scala source files, not generated files
if src.path.endswith(".scala") and src.is_source:
if src.path.endswith(_scala_extension) and src.is_source:
file = ctx.actions.declare_file("{}.fmt.output".format(src.short_path))
files.append(file)
ctx.actions.run(
Expand Down

0 comments on commit fadf4ce

Please sign in to comment.