🔨 refactor(layer_2d): expose indices in VQ2D (#99)
jean-francoisreboud authored Jun 10, 2023
1 parent 84de76b commit d039cdc
Showing 2 changed files with 35 additions and 34 deletions.
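This commit replaces the private `_K` and `_indices` members of `VQ2D` with a public `K` (codebook size) and `indices` (per-position codeword assignments), so calling code can inspect which codebook entry each spatial position was quantized to. A minimal usage sketch, assuming a `VQ2D` layer `vq` from an already-built model that has just run a CPU forward pass; the helper name, the accessibility of `batchSize`/`height`/`width`, and the row-major (batch, height, width) layout are assumptions for illustration, not part of this commit:

```swift
// Hypothetical helper: prints the codeword chosen for every spatial position.
// Assumes the layer ran forwardCPU(), so `indices` is a MetalSharedBuffer.
func dumpVQAssignments(_ vq: VQ2D)
{
    print("Codebook size K = \(vq.K)")

    // `indices` holds batchSize * height * width Int32 values.
    let indicesPtr = (vq.indices as! MetalSharedBuffer<Int32>).buffer
    for elem in 0..<vq.batchSize {
    for i in 0..<vq.height {
    for j in 0..<vq.width
    {
        // Assumed flattening: (batch, height, width), row-major.
        let offset = j + (elem * vq.height + i) * vq.width
        print("batch \(elem), position (\(i), \(j)) -> codeword \(indicesPtr[offset])")
    }}}
}
```

After a GPU forward pass, `indices` is a `MetalPrivateBuffer<Int32>` instead (see the diff below), so its contents would need to be synchronized back to shared memory before being read on the CPU.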
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.

## [unreleased]

🔨 **layer_2d:** expose indices in VQ2D ([#99](https://github.com/owkin/GrAIdient/pull/99))\
🪜 **layer_2d:** loosen range constraint in ColorJitterHSV ([#98](https://github.com/owkin/GrAIdient/pull/98))\
🪜 **layer_2d:** SimilarityError2D & dirty losses ([#97](https://github.com/owkin/GrAIdient/pull/97))\
🔨 **core:** LayerWeightInit ([#96](https://github.com/owkin/GrAIdient/pull/96))\
68 changes: 34 additions & 34 deletions Sources/GrAIdient/Layer2D/VQ2D.swift
@@ -11,7 +11,7 @@ import Foundation
public class VQ2D: Layer2D, LayerWeightInit
{
/// The number of vector approximations.
let _K: Int
public let K: Int

/// Coefficient for commitment.
public var beta: Double
@@ -20,7 +20,7 @@
/// Indices of maximal elements.
/// Shape ~ (batch, height, width).
///
var _indices: MetalBuffer<Int32>! = nil
public var indices: MetalBuffer<Int32>! = nil

///
/// Grid of weights.
@@ -59,7 +59,7 @@
}

var weightsTmp = [Float]()
for k in 0..<_K {
for k in 0..<K {
for depth in 0..<nbChannels
{
weightsTmp.append(Float(_wArrays.w(k, depth)))
@@ -98,7 +98,7 @@
public var connectivityIO: (Int, Int)
{
get {
return (nbChannels, _K)
return (nbChannels, K)
}
}

@@ -123,7 +123,7 @@
beta: Double,
params: GrAI.Model.Params)
{
_K = K
self.K = K
self.beta = beta
super.init(
layerPrev: layerPrev,
@@ -146,7 +146,7 @@
{
let values = try decoder.container(keyedBy: Keys.self)

_K = try values.decode(Int.self, forKey: .K)
K = try values.decode(Int.self, forKey: .K)
beta = try values.decode(Double.self, forKey: .beta)

try super.init(from: decoder)
@@ -170,7 +170,7 @@
{
var container = encoder.container(keyedBy: Keys.self)

try container.encode(_K, forKey: .K)
try container.encode(K, forKey: .K)
try container.encode(beta, forKey: .beta)

let weightsList: [Float]
@@ -210,7 +210,7 @@
params.context.curID = id

let layer = VQ2D(
layerPrev: layerPrev, K: _K, beta: beta, params: params
layerPrev: layerPrev, K: K, beta: beta, params: params
)
if inPlace
{
@@ -242,7 +242,7 @@
{
super.resetKernelCPU()
_wArrays?.reset()
_indices = nil
indices = nil
}

///
@@ -256,7 +256,7 @@
{
super.resetKernelGPU()

_indices = nil
indices = nil
_wDeltaWeights = nil
_wBuffers?.reset()
}
@@ -273,9 +273,9 @@
_weightsList = generateWeightsList()
}

_wArrays = WeightGrids(width: nbChannels, height: _K)
_wArrays = WeightGrids(width: nbChannels, height: K)

for k in 0..<_K {
for k in 0..<K {
for depth in 0..<nbChannels
{
let offset = depth + nbChannels * k
@@ -297,12 +297,12 @@
}

_wBuffers = WeightBuffers(
nbElems: _K * nbChannels,
nbElems: K * nbChannels,
deviceID: deviceID
)

let weightsPtr = _wBuffers.w_p!.shared.buffer
for elem in 0..<_K * nbChannels
for elem in 0..<K * nbChannels
{
weightsPtr[elem] = _weightsList[elem]
}
@@ -321,9 +321,9 @@
{
try super.checkStateCPU(batchSize: batchSize)

if _indices == nil
if indices == nil
{
_indices = MetalSharedBuffer<Int32>(
indices = MetalSharedBuffer<Int32>(
batchSize * height * width,
deviceID: deviceID
)
@@ -344,13 +344,13 @@
GrAI.Gradient.sample && _wDeltaWeights == nil
{
_wDeltaWeights = MetalPrivateBuffer<Float>(
batchSize * _K * nbChannels, deviceID: deviceID
batchSize * K * nbChannels, deviceID: deviceID
)
}

if _indices == nil
if indices == nil
{
_indices = MetalPrivateBuffer<Int32>(
indices = MetalPrivateBuffer<Int32>(
batchSize * height * width,
deviceID: deviceID
)
@@ -369,7 +369,7 @@
try checkStateCPU(batchSize: batchSize)

let neuronsPrev = layerPrev.neurons
let indicesPtr = (_indices as! MetalSharedBuffer<Int32>).buffer
let indicesPtr = (indices as! MetalSharedBuffer<Int32>).buffer

for elem in 0..<batchSize {
for i in 0..<height {
@@ -378,7 +378,7 @@
var minIndex = -1
var minValue: Double? = nil

for k in 0..<_K
for k in 0..<K
{
var value: Double = 0.0
for depth in 0..<nbChannels
@@ -425,7 +425,7 @@
let pNbChannels: [UInt32] = [UInt32(nbChannels)]
let pNbBatch: [UInt32] = [UInt32(batchSize)]
let pDimensions: [UInt32] = [UInt32(width), UInt32(height)]
let pK: [UInt32] = [UInt32(_K)]
let pK: [UInt32] = [UInt32(K)]

let command = MetalKernel.get.createCommand(
"vq2DForward", deviceID: deviceID
@@ -437,7 +437,7 @@
command.setBytes(pK, atIndex: 4)
command.setBytes(pNbBatch, atIndex: 5)
command.setBuffer(outs.metal, atIndex: 6)
command.setBuffer(_indices.metal, atIndex: 7)
command.setBuffer(indices.metal, atIndex: 7)

command.dispatchThreads(
width: height * width,
@@ -459,7 +459,7 @@
if let layerPrev = self.layerPrev as? Layer2D, mustComputeBackward
{
let neuronsPrev = layerPrev.neurons
let indicesPtr = (_indices as! MetalSharedBuffer<Int32>).buffer
let indicesPtr = (indices as! MetalSharedBuffer<Int32>).buffer

for elem in 0..<batchSize {
for i in 0..<height {
@@ -498,11 +498,11 @@
{
let neuronsPrev = layerPrev.neurons
let coeff = batchSize * height * width
let indicesPtr = (_indices as! MetalSharedBuffer<Int32>).buffer
let indicesPtr = (indices as! MetalSharedBuffer<Int32>).buffer

if !accumulateDeltaWeights
{
for k in 0..<_K {
for k in 0..<K {
for depth in 0..<nbChannels
{
_wArrays.g(k, depth, 0.0)
@@ -550,7 +550,7 @@
let pNbChannels: [UInt32] = [UInt32(nbChannels)]
let pNbBatch: [UInt32] = [UInt32(batchSize)]
let pDimensions: [UInt32] = [UInt32(width), UInt32(height)]
let pK: [UInt32] = [UInt32(_K)]
let pK: [UInt32] = [UInt32(K)]
let pBeta: [Float] = [Float(beta)]
let pDirty: [UInt32] = layerPrev.dirty ? [1] : [0]

@@ -560,7 +560,7 @@
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(delta.metal, atIndex: 1)
command.setBuffer(_wBuffers.w.metal, atIndex: 2)
command.setBuffer(_indices.metal, atIndex: 3)
command.setBuffer(indices.metal, atIndex: 3)
command.setBytes(pNbChannels, atIndex: 4)
command.setBytes(pDimensions, atIndex: 5)
command.setBytes(pK, atIndex: 6)
@@ -586,7 +586,7 @@
let pNbChannels: [UInt32] = [UInt32(nbChannels)]
let pNbBatch: [UInt32] = [UInt32(batchSize)]
let pDimensions: [UInt32] = [UInt32(width), UInt32(height)]
let pK: [UInt32] = [UInt32(_K)]
let pK: [UInt32] = [UInt32(K)]
let pAccumulate: [UInt32] = accumulateDeltaWeights ? [1] : [0]

var command: MetalCommand
Expand Down Expand Up @@ -615,14 +615,14 @@ public class VQ2D: Layer2D, LayerWeightInit
)
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(_wBuffers.w.metal, atIndex: 1)
command.setBuffer(_indices.metal, atIndex: 2)
command.setBuffer(indices.metal, atIndex: 2)
command.setBytes(pNbChannels, atIndex: 3)
command.setBytes(pDimensions, atIndex: 4)
command.setBytes(pK, atIndex: 5)
command.setBytes(pNbBatch, atIndex: 6)
command.setBuffer(_wBuffers.g.metal, atIndex: 7)

command.dispatchThreads(width: nbChannels, height: _K)
command.dispatchThreads(width: nbChannels, height: K)
command.enqueue()
}
else
Expand All @@ -647,7 +647,7 @@ public class VQ2D: Layer2D, LayerWeightInit
)
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(_wBuffers.w.metal, atIndex: 1)
command.setBuffer(_indices.metal, atIndex: 2)
command.setBuffer(indices.metal, atIndex: 2)
command.setBytes(pNbChannels, atIndex: 3)
command.setBytes(pDimensions, atIndex: 4)
command.setBytes(pK, atIndex: 5)
Expand All @@ -656,7 +656,7 @@ public class VQ2D: Layer2D, LayerWeightInit

command.dispatchThreads(
width: nbChannels,
height: batchSize * _K
height: batchSize * K
)
command.enqueue()

Expand All @@ -673,7 +673,7 @@ public class VQ2D: Layer2D, LayerWeightInit
command.setBytes(pAccumulate, atIndex: 4)
command.setBuffer(_wBuffers.g.metal, atIndex: 5)

command.dispatchThreads(width: nbChannels, height: _K)
command.dispatchThreads(width: nbChannels, height: K)
command.enqueue()
}
}
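For reference, the CPU forward pass picks, for every spatial position, the codeword closest to the input activation and writes that index into `indices`; the distance computation itself sits in a collapsed part of the hunk above, so the following is a self-contained sketch of the standard vector-quantization rule (plain arrays in place of GrAIdient's neuron grids and Metal buffers; every name here is illustrative):

```swift
// Illustrative nearest-codeword selection, mirroring the argmin loop in
// VQ2D.forwardCPU: `input` is one position's activation over nbChannels,
// `codebook` holds K codewords with nbChannels components each.
func nearestCodeword(input: [Double], codebook: [[Double]]) -> Int
{
    var minIndex = -1
    var minValue: Double? = nil

    for (k, codeword) in codebook.enumerated()
    {
        // Squared Euclidean distance between the activation and codeword k.
        var value = 0.0
        for depth in 0..<input.count
        {
            let diff = input[depth] - codeword[depth]
            value += diff * diff
        }

        if minValue == nil || value < minValue!
        {
            minValue = value
            minIndex = k
        }
    }
    return minIndex
}

// Example: 2 channels, 3 codewords; the activation is closest to codeword 1.
let index = nearestCodeword(
    input: [0.2, -0.1],
    codebook: [[1.0, 1.0], [0.0, 0.0], [-1.0, -1.0]]
)
// index == 1
```

With `indices` now public, downstream code (for example a decoder or an analysis pass) can consume these assignments directly instead of re-deriving them from the quantized outputs.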
