Skip to content

Commit

Permalink
multicast support remote publish/subscribe
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudwu committed Apr 29, 2014
1 parent fa6191d commit 5ce6505
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 19 deletions.
2 changes: 2 additions & 0 deletions examples/main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ skynet.start(function()
maxclient = max_client,
})

skynet.newservice("testmulticast")

skynet.exit()
end)
71 changes: 62 additions & 9 deletions lualib-src/lua-multicast.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@
#include <lua.h>
#include <lauxlib.h>
#include <stdint.h>
#include <string.h>

struct mc_package {
int reference;
uint32_t size;
void *data;
};

static int
pack(lua_State *L, void *data, size_t size) {
struct mc_package * pack = skynet_malloc(sizeof(struct mc_package));
pack->reference = 0;
pack->size = (uint32_t)size;
pack->data = data;
struct mc_package ** ret = skynet_malloc(sizeof(*ret));
*ret = pack;
lua_pushlightuserdata(L, ret);
lua_pushinteger(L, sizeof(ret));
return 2;
}

/*
lightuserdata
integer size
Expand All @@ -23,15 +37,37 @@ mc_packlocal(lua_State *L) {
if (size != (uint32_t)size) {
return luaL_error(L, "Size should be 32bit integer");
}
struct mc_package * pack = skynet_malloc(sizeof(struct mc_package));
pack->reference = 0;
pack->size = (uint32_t)size;
pack->data = data;
struct mc_package ** ret = skynet_malloc(sizeof(*ret));
*ret = pack;
lua_pushlightuserdata(L, ret);
lua_pushinteger(L, sizeof(ret));
return 2;
return pack(L, data, size);
}

/*
lightuserdata
integer size
return lightuserdata, sizeof(struct mc_package *)
*/
static int
mc_packremote(lua_State *L) {
void * data = lua_touserdata(L, 1);
size_t size = luaL_checkunsigned(L, 2);
if (size != (uint32_t)size) {
return luaL_error(L, "Size should be 32bit integer");
}
void * msg = skynet_malloc(size);
memcpy(msg, data, size);
return pack(L, msg, size);
}

static int
mc_packstring(lua_State *L) {
size_t size;
const char * msg = luaL_checklstring(L, 1, &size);
if (size != (uint32_t)size) {
return luaL_error(L, "string is too long");
}
void * data = skynet_malloc(size);
memcpy(data, msg, size);
return pack(L, data, size);
}

/*
Expand Down Expand Up @@ -89,13 +125,30 @@ mc_closelocal(lua_State *L) {
return 0;
}

/*
lightuserdata struct mc_package **
return lightuserdata/size
*/
static int
mc_remote(lua_State *L) {
struct mc_package **ptr = lua_touserdata(L,1);
struct mc_package *pack = *ptr;
lua_pushlightuserdata(L, pack->data);
lua_pushunsigned(L, pack->size);
skynet_free(pack);
return 2;
}

int
luaopen_multicast_c(lua_State *L) {
luaL_Reg l[] = {
{ "pack", mc_packlocal },
{ "unpack", mc_unpacklocal },
{ "bind", mc_bindrefer },
{ "close", mc_closelocal },
{ "remote", mc_remote },
{ "packstring", mc_packstring },
{ "packremote", mc_packremote },
{ NULL, NULL },
};
luaL_checkversion(L);
Expand Down
2 changes: 1 addition & 1 deletion service/bootstrap.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ skynet.start(function()
if skynet.getenv "standalone" then
local datacenter = assert(skynet.newservice "datacenterd")
skynet.name("DATACENTER", datacenter)
local smgr = assert(skynet.newservice "service_mgr")
end
assert(skynet.newservice "service_mgr")
skynet.uniqueservice("multicastd")
assert(skynet.newservice(skynet.getenv "start" or "main"))
skynet.exit()
Expand Down
2 changes: 1 addition & 1 deletion service/datacenterd.lua
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ end
skynet.start(function()
skynet.dispatch("lua", function (_, source, cmd, ...)
local f = assert(command[cmd])
skynet.ret(skynet.pack(f(source, ...)))
skynet.ret(skynet.pack(f(...)))
end)
end)
88 changes: 81 additions & 7 deletions service/multicastd.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,21 @@ local skynet = require "skynet"
local mc = require "multicast.c"
local datacenter = require "datacenter"

local harbor_id = skynet.harbor(skynet.self())

local command = {}
local channel = {}
local channel_n = {}
local channel_id = skynet.harbor(skynet.self())
local channel_remote = {}
local channel_id = harbor_id

local function get_address(t, id)
local v = assert(datacenter.get("multicast", id))
t[id] = v
return v
end

local node_address = setmetatable({}, { __index = get_address })

function command.NEW()
channel[channel_id] = {}
Expand All @@ -15,36 +26,99 @@ function command.NEW()
return ret
end

function command.PUB(source, c, pack, size)
local function remote_publish(node, channel, source, ...)
skynet.redirect(node_address[node], source, "multicast", channel, ...)
end

local function publish(c , source, pack, size)
local group = assert(channel[c])
mc.bind(pack, channel_n[c])
local msg = skynet.tostring(pack, size)
for k in pairs(group) do
skynet.redirect(k, source, "multicast", c , msg)
end
local remote = channel_remote[c]
if remote then
local _, msg, sz = mc.unpack(pack, size)
local msg = skynet.tostring(msg,sz)
for node in pairs(remote) do
remote_publish(node, c, source, msg)
end
end
end

skynet.register_protocol {
name = "multicast",
id = skynet.PTYPE_MULTICAST,
unpack = function(msg, sz)
return mc.packremote(msg, sz)
end,
dispatch = publish,
}

function command.PUB(source, c, pack, size)
assert(skynet.harbor(source) == harbor_id)
local node = c % 256
if node ~= harbor_id then
-- remote publish
remote_publish(node, c, source, mc.remote(pack))
else
publish(c, source, pack,size)
end
end

function command.SUBR(source, c)
local node = skynet.harbor(source)
assert(node ~= harbor_id)
local group = channel_remote[c]
if group == nil then
group = {}
channel_remote[c] = group
end
group[node] = true
end

function command.SUB(source, c)
local node = c % 256
if node ~= harbor_id then
-- remote group
if channel[c] == nil then
channel[c] = {}
channel_n[c] = 0
skynet.call(node_address[node], "lua", "SUBR", c)
end
end
local group = assert(channel[c])
if not group[source] then
channel_n[c] = channel_n[c] + 1
group[source] = true
end
end

function command.USUBR(source, c)
local node = skynet.harbor(source)
assert(node ~= harbor_id)
local group = assert(channel_remote[c])
group[node] = nil
end

function command.USUB(source, c)
local group = assert(channel[c])
if group[source] then
group[source] = nil
channel_n[c] = channel_n[c] - 1
if channel_n[c] == 0 then
local node = c % 256
if node ~= harbor_id then
-- remote group
channel[c] = nil
channel_n[c] = nil
skynet.call(node_address[node], "lua", "USUBR", c)
end
end
end
end

skynet.register_protocol {
name = "multicast",
id = skynet.PTYPE_MULTICAST,
}

skynet.start(function()
skynet.dispatch("lua", function(_,source, cmd, ...)
local f = assert(command[cmd])
Expand Down
7 changes: 6 additions & 1 deletion test/testmulticast.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local skynet = require "skynet"
local mc = require "multicast"
local dc = require "datacenter"

local mode = ...

Expand All @@ -13,6 +14,7 @@ skynet.start(function()
print(string.format("%s <=== %s (%d)",skynet.address(skynet.self()),skynet.address(source), channel), ...)
end
})
skynet.ret(skynet.pack())
end)
end)

Expand All @@ -23,9 +25,12 @@ skynet.start(function()
print("New channel", channel)
for i=1,10 do
local sub = skynet.newservice("testmulticast", "sub")
skynet.send(sub, "lua", "init", channel)
skynet.call(sub, "lua", "init", channel)
end

print("set channel", channel)
dc.set("CHANNEL", channel)

print(skynet.address(skynet.self()), "===>", channel)
mc.publish(channel, "Hello World")
end)
Expand Down

0 comments on commit 5ce6505

Please sign in to comment.