Skip to content

Commit

Permalink
Module API: fix missing RM_CLIENTINFO_FLAG_SSL. (redis#7666)
Browse files Browse the repository at this point in the history
The `REDISMODULE_CLIENTINFO_FLAG_SSL` flag was already a part of the `RedisModuleClientInfo` structure but was not implemented.
  • Loading branch information
yossigo authored Aug 17, 2020
1 parent fb2a94a commit 64c360c
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ static ssize_t connSocketSyncReadLine(connection *conn, char *ptr, ssize_t size,
return syncReadLine(conn->fd, ptr, size, timeout);
}

static int connSocketGetType(connection *conn) {
(void) conn;

return CONN_TYPE_SOCKET;
}

ConnectionType CT_Socket = {
.ae_handler = connSocketEventHandler,
Expand All @@ -343,7 +348,8 @@ ConnectionType CT_Socket = {
.blocking_connect = connSocketBlockingConnect,
.sync_write = connSocketSyncWrite,
.sync_read = connSocketSyncRead,
.sync_readline = connSocketSyncReadLine
.sync_readline = connSocketSyncReadLine,
.get_type = connSocketGetType
};


Expand Down
9 changes: 9 additions & 0 deletions src/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ typedef enum {
#define CONN_FLAG_CLOSE_SCHEDULED (1<<0) /* Closed scheduled by a handler */
#define CONN_FLAG_WRITE_BARRIER (1<<1) /* Write barrier requested */

#define CONN_TYPE_SOCKET 1
#define CONN_TYPE_TLS 2

typedef void (*ConnectionCallbackFunc)(struct connection *conn);

typedef struct ConnectionType {
Expand All @@ -64,6 +67,7 @@ typedef struct ConnectionType {
ssize_t (*sync_write)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
int (*get_type)(struct connection *conn);
} ConnectionType;

struct connection {
Expand Down Expand Up @@ -194,6 +198,11 @@ static inline ssize_t connSyncReadLine(connection *conn, char *ptr, ssize_t size
return conn->type->sync_readline(conn, ptr, size, timeout);
}

/* Return CONN_TYPE_* for the specified connection */
static inline int connGetType(connection *conn) {
return conn->type->get_type(conn);
}

connection *connCreateSocket();
connection *connCreateAcceptedSocket(int fd);

Expand Down
2 changes: 2 additions & 0 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,8 @@ int modulePopulateClientInfoStructure(void *ci, client *client, int structver) {
ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_TRACKING;
if (client->flags & CLIENT_BLOCKED)
ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_BLOCKED;
if (connGetType(client->conn) == CONN_TYPE_TLS)
ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_SSL;

int port;
connPeerToString(client->conn,ci1->addr,sizeof(ci1->addr),&port);
Expand Down
7 changes: 7 additions & 0 deletions src/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,12 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
return nread;
}

static int connTLSGetType(connection *conn_) {
(void) conn_;

return CONN_TYPE_TLS;
}

ConnectionType CT_TLS = {
.ae_handler = tlsEventHandler,
.accept = connTLSAccept,
Expand All @@ -837,6 +843,7 @@ ConnectionType CT_TLS = {
.sync_write = connTLSSyncWrite,
.sync_read = connTLSSyncRead,
.sync_readline = connTLSSyncReadLine,
.get_type = connTLSGetType
};

int tlsHasPendingData() {
Expand Down
38 changes: 38 additions & 0 deletions tests/modules/misc.c
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,42 @@ int test_setlfu(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_OK;
}

int test_clientinfo(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
{
(void) argv;
(void) argc;

RedisModuleClientInfo ci = { .version = REDISMODULE_CLIENTINFO_VERSION };

if (RedisModule_GetClientInfoById(&ci, RedisModule_GetClientId(ctx)) == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, "failed to get client info");
return REDISMODULE_OK;
}

RedisModule_ReplyWithArray(ctx, 10);
char flags[512];
snprintf(flags, sizeof(flags) - 1, "%s:%s:%s:%s:%s:%s",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_SSL ? "ssl" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_PUBSUB ? "pubsub" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_BLOCKED ? "blocked" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_TRACKING ? "tracking" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_UNIXSOCKET ? "unixsocket" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_MULTI ? "multi" : "");

RedisModule_ReplyWithCString(ctx, "flags");
RedisModule_ReplyWithCString(ctx, flags);
RedisModule_ReplyWithCString(ctx, "id");
RedisModule_ReplyWithLongLong(ctx, ci.id);
RedisModule_ReplyWithCString(ctx, "addr");
RedisModule_ReplyWithCString(ctx, ci.addr);
RedisModule_ReplyWithCString(ctx, "port");
RedisModule_ReplyWithLongLong(ctx, ci.port);
RedisModule_ReplyWithCString(ctx, "db");
RedisModule_ReplyWithLongLong(ctx, ci.db);

return REDISMODULE_OK;
}

int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
REDISMODULE_NOT_USED(argv);
REDISMODULE_NOT_USED(argc);
Expand All @@ -221,6 +257,8 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx,"test.getlfu", test_getlfu,"",0,0,0) == REDISMODULE_ERR)
return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx,"test.clientinfo", test_clientinfo,"",0,0,0) == REDISMODULE_ERR)
return REDISMODULE_ERR;

return REDISMODULE_OK;
}
19 changes: 19 additions & 0 deletions tests/unit/moduleapi/misc.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,23 @@ start_server {tags {"modules"}} {
assert { $was_set == 0 }
}

test {test module clientinfo api} {
# Test basic sanity and SSL flag
set info [r test.clientinfo]
set ssl_flag [expr $::tls ? {"ssl:"} : {":"}]

assert { [dict get $info db] == 9 }
assert { [dict get $info flags] == "${ssl_flag}::::" }

# Test MULTI flag
r multi
r test.clientinfo
set info [lindex [r exec] 0]
assert { [dict get $info flags] == "${ssl_flag}::::multi" }

# Test TRACKING flag
r client tracking on
set info [r test.clientinfo]
assert { [dict get $info flags] == "${ssl_flag}::tracking::" }
}
}

0 comments on commit 64c360c

Please sign in to comment.