feat: fix gflags of UnsafeRowOpt when running in yarn-cluster (4paradigm#1731)

* Add log in type codec

* Fix log

* Fix log2

* Add more log

* Add more log

* Add more log2

* Reset log2

* Add more log3

* Add global static method to set gflag of unsafe row opt

* Remove debug code
tobegit3hub authored Apr 28, 2022
1 parent 440d5a4 commit f83eaa7
Showing 11 changed files with 31 additions and 39 deletions.
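What the change fixes: `FLAGS_enable_spark_unsaferow_format` is a process-global gflag. In yarn-cluster mode the Spark driver and executors are separate JVM processes, so setting the flag through the driver's `EngineOptions` (the old `SetEnableSparkUnsaferowFormat`) never reached the executors, whose generated code then decoded rows with the wrong layout. The commit removes the option and adds a static `Engine.InitializeUnsafeRowOptFlag(boolean)`, threading the boolean through each task closure so every executor sets the flag in its own process. A minimal sketch of the resulting pattern (`config`, `df`, `tag`, `moduleBuffer`, and `jsdkLibraryPath` are illustrative names, not the commit's exact code):

    val isUnsafeRowOpt = config.enableUnsafeRowOptimization // read on the driver
    val result = df.rdd.mapPartitions { iter =>
      // Runs in the executor process, where the driver's gflag value is invisible.
      SqlClusterExecutor.initJavaSdkLibrary(jsdkLibraryPath)
      JitManager.initJitModule(tag, moduleBuffer.getBuffer, isUnsafeRowOpt)
      val jit = JitManager.getJit(tag)
      iter // ... evaluate rows with the JIT ...
    }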
10 changes: 2 additions & 8 deletions hybridse/include/vm/engine.h
@@ -120,13 +120,6 @@ class EngineOptions {
     /// Return the maximum number of entries we can hold for compiling cache.
     inline uint32_t GetMaxSqlCacheSize() const { return max_sql_cache_size_; }
 
-    /// Set `true` to enable spark unsafe row format, default `false`.
-    EngineOptions* SetEnableSparkUnsaferowFormat(bool flag);
-    /// Return whether the engine supports the spark unsafe row format.
-    inline bool IsEnableSparkUnsaferowFormat() const {
-        return enable_spark_unsaferow_format_;
-    }
-
     /// Return JitOptions
     inline hybridse::vm::JitOptions& jit_options() { return jit_options_; }

@@ -140,7 +133,6 @@ class EngineOptions {
     bool enable_batch_window_parallelization_;
     bool enable_window_column_pruning_;
     uint32_t max_sql_cache_size_;
-    bool enable_spark_unsaferow_format_;
     JitOptions jit_options_;
 };
@@ -367,6 +359,8 @@ class Engine {
     /// \brief Initialize LLVM environments
     static void InitializeGlobalLLVM();
 
+    static void InitializeUnsafeRowOptFlag(bool isUnsafeRowOpt);
+
     ~Engine();
 
     /// \brief Compile sql in db and store the results in the session
2 changes: 0 additions & 2 deletions hybridse/src/codec/type_codec.cc
@@ -85,12 +85,10 @@ int32_t GetStrFieldUnsafe(const int8_t* row, uint32_t col_idx,
     // Support Spark UnsafeRow format
     if (FLAGS_enable_spark_unsaferow_format) {
         // Notice that for UnsafeRowOpt, field_offset should be the actual offset of the string column
-
         // For Spark UnsafeRow, the first 32 bits are for the length and the last 32 bits are for the offset.
         *size = *(reinterpret_cast<const uint32_t*>(row + field_offset));
         uint32_t str_value_offset = *(reinterpret_cast<const uint32_t*>(row + field_offset + 4)) + HEADER_LENGTH;
         *data = reinterpret_cast<const char*>(row + str_value_offset);
-
         return 0;
     }
 
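For reference: in Spark's UnsafeRow a variable-length field occupies one 8-byte slot that packs the value's offset from the row base in the upper 32 bits and its length in the lower 32 bits. On a little-endian machine the length half sits at the lower address, which is what the two 32-bit reads above rely on. A small JVM-side illustration of the same decoding (values made up):

    // One UnsafeRow variable-length slot: offset in the high 32 bits, length in the low 32.
    val offsetAndSize = (32L << 32) | 5L             // offset = 32, length = 5
    val length = (offsetAndSize & 0xFFFFFFFFL).toInt // low half  -> 5
    val offset = (offsetAndSize >>> 32).toInt        // high half -> 32
    // The C++ above reads the same halves straight from memory: the length at
    // field_offset and the offset at field_offset + 4, then adds HEADER_LENGTH
    // to skip the hybridse row header.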
15 changes: 5 additions & 10 deletions hybridse/src/vm/engine.cc
@@ -50,16 +50,7 @@ EngineOptions::EngineOptions()
       enable_expr_optimize_(true),
       enable_batch_window_parallelization_(false),
       enable_window_column_pruning_(false),
-      max_sql_cache_size_(50),
-      enable_spark_unsaferow_format_(false) {
-    // TODO(chendihao): Pass the parameter to avoid global gflag
-    FLAGS_enable_spark_unsaferow_format = enable_spark_unsaferow_format_;
-}
-
-EngineOptions* EngineOptions::SetEnableSparkUnsaferowFormat(bool flag) {
-    enable_spark_unsaferow_format_ = flag;
-    FLAGS_enable_spark_unsaferow_format = flag;
-    return this;
+      max_sql_cache_size_(50) {
 }
 
 Engine::Engine(const std::shared_ptr<Catalog>& catalog) : cl_(catalog), options_(), mu_(), lru_cache_() {}
@@ -73,6 +64,10 @@ void Engine::InitializeGlobalLLVM() {
     LLVM_IS_INITIALIZED = true;
 }
 
+void Engine::InitializeUnsafeRowOptFlag(bool isUnsafeRowOpt) {
+    FLAGS_enable_spark_unsaferow_format = isUnsafeRowOpt;
+}
+
 bool Engine::GetDependentTables(const std::string& sql, const std::string& db, EngineMode engine_mode,
                                 std::set<std::pair<std::string, std::string>>* db_tables, base::Status& status) {
     auto info = std::make_shared<hybridse::vm::SqlCompileInfo>();
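Since the new method only writes a process-global gflag, it has to run once in every process that executes compiled code, after the native library has been loaded. A sketch of the per-process initialization order on the JVM side (method names are from this commit; `jsdkLibraryPath` is an assumed variable):

    SqlClusterExecutor.initJavaSdkLibrary(jsdkLibraryPath) // load the native library first
    Engine.InitializeGlobalLLVM()                          // then LLVM, once per process
    Engine.InitializeUnsafeRowOptFlag(true)                // then the row-format gflag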
JitManager.java
@@ -16,7 +16,6 @@
 
 package com._4paradigm.hybridse.sdk;
 
-import com._4paradigm.hybridse.HybridSeLibrary;
 import com._4paradigm.hybridse.vm.Engine;
 import com._4paradigm.hybridse.vm.HybridSeJitWrapper;
 import com._4paradigm.hybridse.vm.JitOptions;
@@ -116,11 +115,12 @@ private static synchronized void initModule(String tag, ByteBuffer moduleBuffer)
      * @param tag tag that specifies a jit
      * @param moduleBuffer ByteBuffer used to initialize the native module
      */
-    public static synchronized void initJitModule(String tag, ByteBuffer moduleBuffer) {
+    public static synchronized void initJitModule(String tag, ByteBuffer moduleBuffer, boolean isUnsafeRowOpt) {
         // Notice that we should load the library before calling this; invoke SqlClusterExecutor.initJavaSdkLibrary()
 
         // ensure worker native
         Engine.InitializeGlobalLLVM();
+        Engine.InitializeUnsafeRowOptFlag(isUnsafeRowOpt);
 
         // ensure worker side module
         if (!JitManager.hasModule(tag)) {
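With the extra parameter, every call site must decide the flag at module-initialization time, so the value travels with the task instead of living only in the driver's options. Calling the updated API looks like this (a sketch; `tag`, `buffer`, `isUnsafeRowOpt`, and `functionName` are assumed to come from the serialized project config):

    SqlClusterExecutor.initJavaSdkLibrary(jsdkLibraryPath) // native library must be loaded first
    JitManager.initJitModule(tag, buffer, isUnsafeRowOpt)  // sets the gflag, then builds the module
    val jit = JitManager.getJit(tag)
    val fn = jit.FindFunction(functionName)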
SparkPlanner.scala
@@ -47,6 +47,7 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppName: String)
   // Ensure native initialized
   SqlClusterExecutor.initJavaSdkLibrary(config.openmldbJsdkLibraryPath)
   Engine.InitializeGlobalLLVM()
+  Engine.InitializeUnsafeRowOptFlag(config.enableUnsafeRowOptimization)
 
   def this(session: SparkSession, sparkAppName: String) = {
     this(session, OpenmldbBatchConfig.fromSparkSession(session), sparkAppName)
@@ -334,10 +335,6 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppName: String)
       logger.info("Disable window parallelization optimization, enable by setting openmldb.window.parallelization")
     }
 
-    if (config.enableUnsafeRowOptimization) {
-      engineOptions.SetEnableSparkUnsaferowFormat(true)
-    }
-
     try {
       sqlEngine = new SqlEngine(dbs, engineOptions)
       val engine = sqlEngine.getEngine
FilterPlan.scala
@@ -57,7 +57,8 @@ object FilterPlan {
       outputSchema = filter.fn_info().fn_schema(),
       moduleTag = ctx.getTag,
       moduleBroadcast = ctx.getSerializableModuleBuffer,
-      hybridseJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath
+      hybridseJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath,
+      ctx.getConf.enableUnsafeRowOptimization
     )
     ctx.getSparkSession.udf.register(regName, conditionUDF)
 
GroupByAggregationPlan.scala
@@ -81,6 +81,7 @@ object GroupByAggregationPlan {
     val groupKeyComparator = HybridseUtil.createGroupKeyComparator(groupIdxs.toArray)
 
     val openmldbJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath
+    val isUnsafeRowOpt = ctx.getConf.enableUnsafeRowOptimization
 
     // Map partition
     val resultRDD = sortedInputDf.rdd.mapPartitions(iter => {
@@ -93,7 +94,7 @@ object GroupByAggregationPlan {
         val tag = projectConfig.moduleTag
         val buffer = projectConfig.moduleNoneBroadcast.getBuffer
         SqlClusterExecutor.initJavaSdkLibrary(openmldbJsdkLibraryPath)
-        JitManager.initJitModule(tag, buffer)
+        JitManager.initJitModule(tag, buffer, isUnsafeRowOpt)
 
         val jit = JitManager.getJit(tag)
         val fn = jit.FindFunction(projectConfig.functionName)
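A note on the two captured vals above: Spark serializes the closure passed to `mapPartitions`, and capturing `ctx` (which holds the SparkSession) would fail serialization or ship far more than needed, so only plain values are read on the driver and captured. The shape of the pattern (a sketch; `tag` and `buffer` stand in for the serialized project config fields):

    val openmldbJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath // read on the driver
    val isUnsafeRowOpt = ctx.getConf.enableUnsafeRowOptimization      // read on the driver
    sortedInputDf.rdd.mapPartitions { iter =>
      // Only the captured plain values travel to the executor, not ctx.
      SqlClusterExecutor.initJavaSdkLibrary(openmldbJsdkLibraryPath)
      JitManager.initJitModule(tag, buffer, isUnsafeRowOpt)
      iter
    }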
JoinPlan.scala
@@ -120,7 +120,8 @@ object JoinPlan {
       outputSchema = filter.fn_info().fn_schema(),
       moduleTag = ctx.getTag,
       moduleBroadcast = ctx.getSerializableModuleBuffer,
-      hybridseJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath
+      hybridseJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath,
+      ctx.getConf.enableUnsafeRowOptimization
     )
     spark.udf.register(regName, conditionUDF)
 
@@ -206,13 +207,15 @@
       outputSchema: java.util.List[ColumnDef],
       moduleTag: String,
       moduleBroadcast: SerializableByteBuffer,
-      hybridseJsdkLibraryPath: String
+      hybridseJsdkLibraryPath: String,
+      isUnsafeRowOpt: Boolean
   ) extends Function1[Row, Boolean] with Serializable {
 
     @transient private lazy val tls = new ThreadLocal[UnSafeJoinConditionUDFImpl]() {
       override def initialValue(): UnSafeJoinConditionUDFImpl = {
         new UnSafeJoinConditionUDFImpl(
-          functionName, inputSchemaSlices, outputSchema, moduleTag, moduleBroadcast, hybridseJsdkLibraryPath)
+          functionName, inputSchemaSlices, outputSchema, moduleTag, moduleBroadcast, hybridseJsdkLibraryPath,
+          isUnsafeRowOpt)
       }
     }
 
@@ -226,7 +229,8 @@
       outputSchema: java.util.List[ColumnDef],
       moduleTag: String,
       moduleBroadcast: SerializableByteBuffer,
-      openmldbJsdkLibraryPath: String
+      openmldbJsdkLibraryPath: String,
+      isUnsafeRowOpt: Boolean
   ) extends Function1[Row, Boolean] with Serializable {
     private val jit = initJIT()
 
@@ -243,7 +247,7 @@
       // ensure worker native
       val buffer = moduleBroadcast.getBuffer
       SqlClusterExecutor.initJavaSdkLibrary(openmldbJsdkLibraryPath)
-      JitManager.initJitModule(moduleTag, buffer)
+      JitManager.initJitModule(moduleTag, buffer, isUnsafeRowOpt)
 
       JitManager.getJit(moduleTag)
     }
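The `@transient lazy val` plus `ThreadLocal` combination above is the usual trick for per-executor native state: the UDF object is serialized to executors, the native JIT handle cannot travel with it, so each executor thread lazily builds its own implementation, and `initJIT` runs there with the shipped `isUnsafeRowOpt`. A stripped-down sketch of the pattern (`Impl` and `NativeUDF` are placeholders; `Row` is org.apache.spark.sql.Row):

    class Impl(isUnsafeRowOpt: Boolean) { // would hold the native JIT handle
      def eval(row: Row): Boolean = ???   // would call into generated code
    }
    class NativeUDF(isUnsafeRowOpt: Boolean) extends (Row => Boolean) with Serializable {
      // Not serialized; rebuilt lazily on each executor, one instance per thread.
      @transient private lazy val tls = new ThreadLocal[Impl] {
        override def initialValue(): Impl = new Impl(isUnsafeRowOpt)
      }
      def apply(row: Row): Boolean = tls.get().eval(row)
    }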
RowProjectPlan.scala
@@ -84,7 +84,7 @@ object RowProjectPlan {
         val tag = projectConfig.moduleTag
         val buffer = projectConfig.moduleNoneBroadcast.getBuffer
         SqlClusterExecutor.initJavaSdkLibrary(openmldbJsdkLibraryPath)
-        JitManager.initJitModule(tag, buffer)
+        JitManager.initJitModule(tag, buffer, isUnsafeRowOpt)
 
         val jit = JitManager.getJit(tag)
         val fn = jit.FindFunction(projectConfig.functionName)
@@ -171,7 +171,7 @@
         val tag = projectConfig.moduleTag
         val buffer = projectConfig.moduleNoneBroadcast
         SqlClusterExecutor.initJavaSdkLibrary(openmldbJsdkLibraryPath)
-        JitManager.initJitModule(tag, buffer.getBuffer)
+        JitManager.initJitModule(tag, buffer.getBuffer, isUnsafeRowOpt)
 
         val jit = JitManager.getJit(tag)
         val fn = jit.FindFunction(projectConfig.functionName)
WindowAggPlanUtil.scala
@@ -116,7 +116,8 @@ object WindowAggPlanUtil {
       excludeCurrentTime: Boolean,
       needAppendInput: Boolean,
       limitCnt: Int,
-      keepIndexColumn: Boolean)
+      keepIndexColumn: Boolean,
+      isUnsafeRowOpt: Boolean)
 
 
   /** Get the data from context and physical node and create the WindowAggConfig object.
@@ -198,7 +199,8 @@
       excludeCurrentTime = node.exclude_current_time(),
       needAppendInput = node.need_append_input(),
       limitCnt = node.GetLimitCnt(),
-      keepIndexColumn = keepIndexColumn
+      keepIndexColumn = keepIndexColumn,
+      isUnsafeRowOpt = ctx.getConf.enableUnsafeRowOptimization
     )
   }
 
@@ -211,7 +213,7 @@
     val tag = config.moduleTag
     val buffer = config.moduleNoneBroadcast.getBuffer
     SqlClusterExecutor.initJavaSdkLibrary(sqlConfig.openmldbJsdkLibraryPath)
-    JitManager.initJitModule(tag, buffer)
+    JitManager.initJitModule(tag, buffer, config.isUnsafeRowOpt)
 
     val jit = JitManager.getJit(tag)
 
WindowSampleSupport.scala
@@ -205,7 +205,7 @@ object WindowSampleSupport {
     private val jit = {
       val buffer = config.moduleNoneBroadcast.getBuffer
       SqlClusterExecutor.initJavaSdkLibrary(sqlConfig.openmldbJsdkLibraryPath)
-      JitManager.initJitModule(config.moduleTag, buffer)
+      JitManager.initJitModule(config.moduleTag, buffer, sqlConfig.enableUnsafeRowOptimization)
       JitManager.getJit(config.moduleTag)
     }
 
