Skip to content

Commit

Permalink
Add -tta option
Browse files Browse the repository at this point in the history
The TTA mode:
- 8x slower than normal mode
- improves PSNR +0.1
  • Loading branch information
nagadomi committed Nov 8, 2015
1 parent 4322b63 commit b335f3a
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 15 deletions.
59 changes: 59 additions & 0 deletions lib/reconstruct.lua
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,64 @@ function reconstruct.scale(model, scale, x, block_size)
reconstruct.offset_size(model), block_size)
end
end
local function tta(f, model, x, block_size)
local average = nil
local offset = reconstruct.offset_size(model)
for i = 1, 4 do
local flip_f, iflip_f
if i == 1 then
flip_f = function (a) return a end
iflip_f = function (a) return a end
elseif i == 2 then
flip_f = image.vflip
iflip_f = image.vflip
elseif i == 3 then
flip_f = image.hflip
iflip_f = image.hflip
elseif i == 4 then
flip_f = function (a) return image.hflip(image.vflip(a)) end
iflip_f = function (a) return image.vflip(image.hflip(a)) end
end
for j = 1, 2 do
local tr_f, itr_f
if j == 1 then
tr_f = function (a) return a end
itr_f = function (a) return a end
elseif j == 2 then
tr_f = function(a) return a:transpose(2, 3):contiguous() end
itr_f = function(a) return a:transpose(2, 3):contiguous() end
end
local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
offset, block_size)))
if not average then
average = out
else
average:add(out)
end
end
end
return average:div(8.0)
end
function reconstruct.image_tta(model, x, block_size)
if reconstruct.is_rgb(model) then
return tta(reconstruct.image_rgb, model, x, block_size)
else
return tta(reconstruct.image_y, model, x, block_size)
end
end
function reconstruct.scale_tta(model, scale, x, block_size)
if reconstruct.is_rgb(model) then
local f = function (model, x, offset, block_size)
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
end
return tta(f, model, x, block_size)

else
local f = function (model, x, offset, block_size)
return reconstruct.scale_y(model, scale, x, offset, block_size)
end
return tta(f, model, x, block_size)
end
end

return reconstruct
66 changes: 51 additions & 15 deletions waifu2x.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ local function convert_image(opt)
local x, alpha = image_loader.load_float(opt.i)
local new_x = nil
local t = sys.clock()
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
else
scale_f = reconstruct.scale
image_f = reconstruct.image
end
if opt.o == "(auto)" then
local name = path.basename(opt.i)
local e = path.extension(name)
Expand All @@ -25,14 +33,14 @@ local function convert_image(opt)
if not model then
error("Load Error: " .. model_path)
end
new_x = reconstruct.image(model, x, opt.crop_size)
new_x = image_f(model, x, opt.crop_size)
elseif opt.m == "scale" then
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
local model = torch.load(model_path, "ascii")
if not model then
error("Load Error: " .. model_path)
end
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
new_x = scale_f(model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" then
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
local noise_model = torch.load(noise_model_path, "ascii")
Expand All @@ -45,34 +53,61 @@ local function convert_image(opt)
if not scale_model then
error("Load Error: " .. scale_model_path)
end
x = reconstruct.image(noise_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
x = image_f(noise_model, x)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
image_loader.save_png(opt.o, new_x, alpha, opt.depth)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
local noise1_model, noise2_model, scale_model
local model_path, noise1_model, noise2_model, scale_model
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
else
scale_f = reconstruct.scale
image_f = reconstruct.image
end
if opt.m == "scale" then
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
scale_model = torch.load(model_path, "ascii")
if not scale_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 1 then
local model_path = path.join(opt.model_dir, "noise1_model.t7")
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 2 then
local model_path = path.join(opt.model_dir, "noise2_model.t7")
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise_scale" then
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
scale_model = torch.load(model_path, "ascii")
if not scale_model then
error("Load Error: " .. model_path)
end
if opt.noise_level == 1 then
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.noise_level == 2 then
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
end
end
local fp = io.open(opt.l)
if not fp then
Expand All @@ -89,17 +124,17 @@ local function convert_frames(opt)
local x, alpha = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = reconstruct.image(noise1_model, x, opt.crop_size)
new_x = image_f(noise1_model, x, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = reconstruct.image(noise2_model, x)
new_x = image_func(noise2_model, x)
elseif opt.m == "scale" then
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = reconstruct.image(noise1_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
x = image_f(noise1_model, x)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = reconstruct.image(noise2_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
x = image_f(noise2_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
Expand Down Expand Up @@ -139,6 +174,7 @@ local function waifu2x()
cmd:option("-crop_size", 128, 'patch size per process')
cmd:option("-resume", 0, "skip existing files (0|1)")
cmd:option("-thread", -1, "number of CPU threads")
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')

local opt = cmd:parse(arg)
if opt.thread > 0 then
Expand Down

0 comments on commit b335f3a

Please sign in to comment.