-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add modules test_rbm and test_rtrbm test_rbm generate some difference images. it is good
- Loading branch information
Showing
7 changed files
with
143 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import utils | ||
|
||
__author__ = 'gavr' | ||
|
||
from rbm import RBM | ||
from rbm import createSimpleRBM | ||
from utils import convertImageToVector | ||
from utils import convertVectorToImage | ||
from utils import saveData | ||
from PIL import Image | ||
from PIL import ImageDraw | ||
from numpy.oldnumeric.random_array import random_integers | ||
import numpy | ||
|
||
def generatorImage(size): | ||
image = Image.new(mode = "P", size = (size, size)) | ||
image.putpalette([255, 255, 255, 0, 0, 0]) | ||
draw = ImageDraw.Draw(image) | ||
f = lambda x, y: random_integers(y, minimum=x) | ||
draw.line((f(1, size/2), f(1, size/2), f(size/2, size), f(size/2, size)), fill = 1) | ||
return image | ||
|
||
def generatorWrongImage(size): | ||
image = Image.new(mode = "P", size = (size, size)) | ||
image.putpalette([255, 255, 255, 0, 0, 0]) | ||
draw = ImageDraw.Draw(image) | ||
f = lambda x, y: random_integers(y, minimum=x) | ||
draw.line((f(size / 2, size), f(1, size / 2), f(1, size / 2), f(size / 2, size)), fill = 1) | ||
return image | ||
|
||
size = 20 | ||
# generate Data | ||
datasize = 2000 | ||
data = [convertImageToVector(generatorImage(size)) for i in range(0, datasize)] | ||
rbm = createSimpleRBM(100, size * size) | ||
#saveData(rbm.saveTo().getvalue()) | ||
#rbm = openRBM(getStringData()) | ||
print 'start train' | ||
|
||
for idx in range(0, 80): | ||
for index in range(0, 20): | ||
print idx, rbm.grad_step(data[index * 100: (index+1) * 100 - 1], numpy.asarray(0.01, dtype='float32'), 20) | ||
|
||
print 'control train data' | ||
|
||
for obj in data: | ||
print rbm.freeEnergy(obj) | ||
|
||
print 'control train data' | ||
|
||
data = [convertImageToVector(generatorImage(size)) for i in range(0, 10)] | ||
|
||
for obj in data: | ||
print rbm.freeEnergy(obj) | ||
|
||
print 'randomInfo' | ||
|
||
for idx in range(0, 5): | ||
x = rbm.generateVisibles() | ||
print rbm.freeEnergy(x) | ||
x1 = rbm.gibbs(x, 1) | ||
print rbm.freeEnergy(x1) | ||
x2 = rbm.gibbs(x, 10) | ||
print rbm.freeEnergy(x2) | ||
|
||
print 'WringImage' | ||
|
||
for idx in range(0, 5): | ||
x = generatorWrongImage(size) | ||
x = convertImageToVector(x) | ||
print rbm.freeEnergy(x) | ||
x1 = rbm.gibbs(x, 1) | ||
print rbm.freeEnergy(x1) | ||
x2 = rbm.gibbs(x, 10) | ||
print rbm.freeEnergy(x2) | ||
|
||
convertVectorToImage(generatorImage(size), rbm.gibbs(convertImageToVector(generatorImage(size)), 1)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbs(convertImageToVector(generatorImage(size)), 5)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbs(convertImageToVector(generatorImage(size)), 10)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbs(convertImageToVector(generatorImage(size)), 20)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbs(convertImageToVector(generatorImage(size)), 30)).show() | ||
|
||
convertVectorToImage(generatorImage(size), rbm.gibbsFromRnd(1)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbsFromRnd(5)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbsFromRnd(10)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbsFromRnd(20)).show() | ||
convertVectorToImage(generatorImage(size), rbm.gibbsFromRnd(30)).show() | ||
|
||
|
||
saveData(rbm.saveTo().getvalue()) | ||
print 'saving has been made' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
__author__ = 'gavr' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
__author__ = 'gavr' | ||
|
||
import StringIO | ||
import numpy | ||
|
||
def convertImageToVector(image): | ||
return numpy.asarray(list(image.getdata())) | ||
|
||
def convertVectorToImage(appearance, vector): | ||
im = appearance.copy() | ||
im.putdata(vector) | ||
return im | ||
|
||
# save Data | ||
def saveData(strio): | ||
file = open('data000.txt', 'w') | ||
file.write(strio) | ||
file.close() | ||
|
||
# readData from data.txt | ||
def getStringData(): | ||
file = open('data.txt', 'r') | ||
s = StringIO.StringIO() | ||
output = file.readlines() | ||
s.writelines(output) | ||
file.close() | ||
return s.getvalue() | ||
|
||
|