Add support for getters and a Flatten impl
Josh Wills committed Dec 2, 2014
1 parent 47fb4c1 commit cb91f2a
Showing 3 changed files with 54 additions and 61 deletions.
src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java (12 changes: 6 additions & 6 deletions)
@@ -17,7 +17,6 @@
  */
 package com.cloudera.dataflow.spark;
 
-import com.google.api.client.util.Lists;
 import com.google.cloud.dataflow.sdk.runners.PipelineOptions;
 import com.google.cloud.dataflow.sdk.streaming.KeyedState;
 import com.google.cloud.dataflow.sdk.transforms.Aggregator;
@@ -28,6 +27,7 @@
 import org.apache.spark.api.java.function.FlatMapFunction;
 
 import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 
 public class DoFnFunction<I, O> implements FlatMapFunction<Iterator<I>, O> {
@@ -50,9 +50,9 @@ public Iterable<O> call(Iterator<I> iter) throws Exception {
     return ctxt.outputs;
   }
 
-  private static class ProcCtxt<I, O> extends DoFn<I, O>.ProcessContext {
+  private class ProcCtxt<I, O> extends DoFn<I, O>.ProcessContext {
 
-    private List<O> outputs = Lists.newArrayList();
+    private List<O> outputs = new LinkedList<>();
     private I element;
 
     public ProcCtxt(DoFn<I, O> fn) {
@@ -65,7 +65,7 @@ public PipelineOptions getPipelineOptions() {
     }
 
     @Override
-    public void output(O o) {
+    public synchronized void output(O o) {
      outputs.add(o);
    }
 
@@ -75,14 +75,14 @@ public <T> void sideOutput(TupleTag<T> tupleTag, T t) {
 
     @Override
     public <AI, AA, AO> Aggregator<AI> createAggregator(
-        String s,
+        String named,
         Combine.CombineFn<? super AI, AA, AO> combineFn) {
       return null;
     }
 
     @Override
     public <AI, AO> Aggregator<AI> createAggregator(
-        String s,
+        String named,
         SerializableFunction<Iterable<AI>, AO> sfunc) {
       return null;
     }
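
Note on the pattern above: DoFnFunction adapts a Dataflow DoFn to Spark's FlatMapFunction by draining each partition's iterator through the fn and returning whatever the ProcessContext buffered in outputs (the new synchronized modifier presumably guards against concurrent emits into that buffer). Below is a minimal, dependency-free sketch of the same adapter pattern; SimpleDoFn, callOnPartition, and their signatures are simplified stand-ins for illustration, not the Dataflow or Spark APIs.

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

public class PartitionAdapterSketch {

  // Simplified stand-in for a Dataflow-style DoFn: one input element in,
  // zero or more output elements emitted.
  interface SimpleDoFn<I, O> {
    void processElement(I element, List<O> emit);
  }

  // Mirrors DoFnFunction.call: drain the whole partition through the fn,
  // buffering every emitted value, then hand the buffer back to the caller.
  static <I, O> Iterable<O> callOnPartition(SimpleDoFn<I, O> fn, Iterator<I> partition) {
    List<O> outputs = new LinkedList<>(); // same buffer choice as the patch
    while (partition.hasNext()) {
      fn.processElement(partition.next(), outputs);
    }
    return outputs;
  }

  public static void main(String[] args) {
    Iterator<String> partition = Arrays.asList("a", "bb", "ccc").iterator();
    Iterable<Integer> lengths =
        callOnPartition((word, emit) -> emit.add(word.length()), partition);
    System.out.println(lengths); // prints [1, 2, 3]
  }
}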
src/main/java/com/cloudera/dataflow/spark/SparkPipelineRunner.java (102 changes: 48 additions & 54 deletions)
@@ -28,10 +28,14 @@
 import com.google.cloud.dataflow.sdk.transforms.Combine;
 import com.google.cloud.dataflow.sdk.transforms.Create;
 import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.Flatten;
 import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
 import com.google.cloud.dataflow.sdk.transforms.PTransform;
 import com.google.cloud.dataflow.sdk.transforms.ParDo;
 import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+import com.google.cloud.dataflow.sdk.values.POutput;
 import com.google.cloud.dataflow.sdk.values.PValue;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
@@ -65,7 +69,7 @@ public SparkPipelineRunner(String master) {
 
   @Override
   public EvaluationResult run(Pipeline pipeline) {
-    EvaluationContext ctxt = new EvaluationContext(this.master);
+    EvaluationContext ctxt = new EvaluationContext(this.master, pipeline);
     pipeline.traverseTopologically(new Evaluator(ctxt));
     return ctxt;
   }
@@ -108,104 +112,81 @@ public static interface EvaluationResult extends PipelineResult {
 
   public static class EvaluationContext implements EvaluationResult {
     final JavaSparkContext jsc;
-    JavaRDDLike last;
+    final Pipeline pipeline;
+    final Map<PValue, JavaRDDLike> rdds = Maps.newHashMap();
 
-    public EvaluationContext(String master) {
+    public EvaluationContext(String master, Pipeline pipeline) {
       this.jsc = new JavaSparkContext(master, "dataflow");
+      this.pipeline = pipeline;
     }
 
     JavaSparkContext getSparkContext() {
       return jsc;
     }
-    void setLast(JavaRDDLike rdd) {
-      last = rdd;
-    }
 
-    JavaRDDLike getLast() { return last; }
-  }
+    Pipeline getPipeline() { return pipeline; }
 
-  public static interface TransformEvaluator<PT extends PTransform> extends Serializable {
-    void evaluate(PT transform, EvaluationContext context);
-  }
+    void setRDD(PTransform transform, JavaRDDLike rdd) {
+      rdds.put((PValue) pipeline.getOutput(transform), rdd);
+    }
 
-  private static class FieldGetter {
-    private Map<String, Field> fields;
-
-    public FieldGetter(Class<?> clazz) {
-      this.fields = Maps.newHashMap();
-      for (Field f : clazz.getDeclaredFields()) {
-        try {
-          f.setAccessible(true);
-          this.fields.put(f.getName(), f);
-          System.err.println("Field " + f.getName() + " for class = " + clazz);
-        } catch (Exception e) {
-          throw new RuntimeException(e);
-        }
-      }
+    JavaRDDLike getRDD(PCollection pcollection) {
+      return rdds.get(pcollection);
     }
 
-    public <T> T get(String fieldname, Object value) {
-      try {
-        return (T) fields.get(fieldname).get(value);
-      } catch (IllegalAccessException e) {
-        throw new RuntimeException(e);
-      }
+    JavaRDDLike getRDD(PTransform transform) {
+      return rdds.get(pipeline.getInput(transform));
     }
   }
 
-  private static FieldGetter READ_TEXT_FG = new FieldGetter(TextIO.Read.Bound.class);
-  private static FieldGetter WRITE_TEXT_FG = new FieldGetter(TextIO.Write.Bound.class);
+  public static interface TransformEvaluator<PT extends PTransform> extends Serializable {
+    void evaluate(PT transform, EvaluationContext context);
+  }

   private static TransformEvaluator<TextIO.Read.Bound> READ_TEXT = new TransformEvaluator<TextIO.Read.Bound>() {
     @Override
     public void evaluate(TextIO.Read.Bound transform, EvaluationContext context) {
-      String pattern = READ_TEXT_FG.get("filepattern", transform);
+      String pattern = transform.getFilepattern();
       JavaRDD rdd = context.getSparkContext().textFile(pattern);
-      Coder coder = READ_TEXT_FG.get("coder", transform);
-      if (coder != null) {
-        //TODO
-      }
-      context.setLast(rdd);
+      // TODO: handle coders
+      context.setRDD(transform, rdd);
     }
   };

   private static TransformEvaluator<TextIO.Write.Bound> WRITE_TEXT = new TransformEvaluator<TextIO.Write.Bound>() {
     @Override
     public void evaluate(TextIO.Write.Bound transform, EvaluationContext context) {
-      JavaRDDLike last = context.getLast();
-      Coder coder = WRITE_TEXT_FG.get("coder", transform);
+      JavaRDDLike last = context.getRDD(transform);
+      Coder coder = null;
       if (coder != null) {
         //TODO
       }
-      String pattern = WRITE_TEXT_FG.get("filenamePrefix", transform);
+      String pattern = transform.getFilenamePrefix();
       last.saveAsTextFile(pattern);
     }
   };

-  private static FieldGetter CREATE_FG = new FieldGetter(Create.class);
   private static TransformEvaluator<Create> CREATE = new TransformEvaluator<Create>() {
     @Override
     public void evaluate(Create transform, EvaluationContext context) {
-      Iterable elems = CREATE_FG.get("elems", transform);
+      Iterable elems = transform.getElems();
       JavaRDD rdd = context.getSparkContext().parallelize(Lists.newLinkedList(elems));
-      context.setLast(rdd);
+      context.setRDD(transform, rdd);
     }
   };

-  private static FieldGetter PARDO_FG = new FieldGetter(ParDo.Bound.class);
   private static TransformEvaluator<ParDo.Bound> PARDO = new TransformEvaluator<ParDo.Bound>() {
     @Override
     public void evaluate(ParDo.Bound transform, EvaluationContext context) {
-      JavaRDDLike last = context.getLast();
-      DoFnFunction dofn = new DoFnFunction(PARDO_FG.<DoFn>get("fn", transform));
-      context.setLast(last.mapPartitions(dofn));
+      JavaRDDLike last = context.getRDD(transform);
+      DoFnFunction dofn = new DoFnFunction(transform.getFn());
+      context.setRDD(transform, last.mapPartitions(dofn));
     }
   };

   private static TransformEvaluator<GroupByKey> GBK = new TransformEvaluator<GroupByKey>() {
     @Override
     public void evaluate(GroupByKey transform, EvaluationContext context) {
-      context.setLast(fromPair(toPair(context.getLast()).groupByKey()));
+      context.setRDD(transform, fromPair(toPair(context.getRDD(transform)).groupByKey()));
     }
 
     private JavaPairRDD toPair(JavaRDDLike rdd) {
@@ -229,12 +210,11 @@ public Object call(Object o) throws Exception {
       }
     };
 
-  private static FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class);
   private static TransformEvaluator<Combine.GroupedValues> GROUPED = new TransformEvaluator<Combine.GroupedValues>() {
     @Override
     public void evaluate(Combine.GroupedValues transform, EvaluationContext context) {
-      final Combine.KeyedCombineFn keyed = GROUPED_FG.get("fn", transform);
-      context.setLast(context.getLast().map(new Function() {
+      final Combine.KeyedCombineFn keyed = transform.getFn();
+      context.setRDD(transform, context.getRDD(transform).map(new Function() {
         @Override
         public Object call(Object input) throws Exception {
           KV<Object, Iterable> kv = (KV<Object, Iterable>) input;
@@ -244,12 +224,26 @@ public Object call(Object input) throws Exception {
       }
     };
 
+  private static TransformEvaluator<Flatten> FLATTEN = new TransformEvaluator<Flatten>() {
+    @Override
+    public void evaluate(Flatten transform, EvaluationContext context) {
+      PCollectionList<?> pcs = (PCollectionList<?>) context.getPipeline().getInput(transform);
+      JavaRDD[] rdds = new JavaRDD[pcs.size()];
+      for (int i = 0; i < rdds.length; i++) {
+        rdds[i] = (JavaRDD) context.getRDD(pcs.get(i));
+      }
+      JavaRDD rdd = context.getSparkContext().union(rdds);
+      context.setRDD(transform, rdd);
+    }
+  };

   private static final Map<Class, TransformEvaluator> EVALUATORS = ImmutableMap.<Class, TransformEvaluator>builder()
       .put(Combine.GroupedValues.class, GROUPED)
       .put(GroupByKey.class, GBK)
       .put(ParDo.Bound.class, PARDO)
       .put(TextIO.Read.Bound.class, READ_TEXT)
       .put(TextIO.Write.Bound.class, WRITE_TEXT)
       .put(Create.class, CREATE)
+      .put(Flatten.class, FLATTEN)
       .build();
}
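
The structural change in this file is replacing EvaluationContext's single mutable last slot with a map keyed by each transform's output PValue, which is what lets a multi-input transform such as Flatten look up every one of its inputs. Below is a minimal, dependency-free sketch of that bookkeeping; the class name and the String/List stand-ins for PValue and JavaRDDLike are illustrative only, not the SDK's types.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class EvaluationRegistrySketch {

  // Stand-in for Map<PValue, JavaRDDLike>: String names model output
  // PValues, List<Integer> models an RDD's contents.
  private final Map<String, List<Integer>> rdds = new HashMap<>();

  // Mirrors setRDD: record a result under the value that produced it.
  void setRDD(String outputValue, List<Integer> rdd) {
    rdds.put(outputValue, rdd);
  }

  // Mirrors getRDD(PCollection): fetch a previously materialized result.
  List<Integer> getRDD(String value) {
    return rdds.get(value);
  }

  // Mirrors the FLATTEN evaluator: look up each input collection and
  // union them (modeled here as list concatenation).
  List<Integer> evaluateFlatten(List<String> inputValues) {
    List<Integer> union = new ArrayList<>();
    for (String input : inputValues) {
      union.addAll(getRDD(input));
    }
    return union;
  }

  public static void main(String[] args) {
    EvaluationRegistrySketch ctxt = new EvaluationRegistrySketch();
    ctxt.setRDD("words1", Arrays.asList(1, 2));
    ctxt.setRDD("words2", Arrays.asList(3, 4));
    // A single "last" pointer could only see words2 here; the map sees both.
    System.out.println(ctxt.evaluateFlatten(Arrays.asList("words1", "words2"))); // [1, 2, 3, 4]
  }
}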
Third changed file (1 change: 0 additions & 1 deletion)
@@ -17,7 +17,6 @@
  */
 package com.cloudera.dataflow.spark;
 
-import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.PipelineResult;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.transforms.Count;
