forked from PaddlePaddle/Paddle.js
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#166 from JingyuanZhang/master
feat(core): add webgl_pack_out feature
- Loading branch information
Showing
14 changed files
with
262 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
packages/paddlejs-backend-webgl/src/ops/shader/custom/index.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import nhwc_2_nchw from './nhwc_2_nchw'; | ||
import pack_out from './pack_out'; | ||
import unpacked_2_packed from './unpacked_2_packed'; | ||
import packed_2_unpacked from './packed_2_unpacked'; | ||
|
||
export { | ||
nhwc_2_nchw, | ||
pack_out, | ||
unpacked_2_packed, | ||
packed_2_unpacked | ||
}; |
36 changes: 36 additions & 0 deletions
36
packages/paddlejs-backend-webgl/src/ops/shader/custom/nhwc_2_nchw.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
/** | ||
* @file fetch | ||
*/ | ||
|
||
function mainFunc( | ||
{ origin, out }, | ||
{} | ||
) { | ||
return ` | ||
void main() { | ||
ivec4 oPos = getOutputTensorPos(); | ||
// 输出坐标转换为输入坐标 | ||
int sumVal = oPos.a * ${out.channel} | ||
+ oPos.b * ${out.width_shape} * ${out.channel} | ||
+ oPos.g | ||
+ oPos.r * ${out.channel} * ${out.width_shape} * ${out.height_shape}; | ||
ivec4 new_oPos = transferFromNHWCtoNCHW( | ||
sumVal, | ||
${origin.channel}, | ||
${origin.width_shape}, | ||
${origin.height_shape}, | ||
${origin.total_shape} | ||
); | ||
float o = getValueFromTensorPos_origin(new_oPos.r, new_oPos.g, new_oPos.b, new_oPos.a); | ||
setOutput(float(o)); | ||
} | ||
`; | ||
} | ||
export default { | ||
mainFunc, | ||
params: [], | ||
textureFuncConf: { | ||
origin: ['getValueFromTensorPos'] | ||
}, | ||
commonFuncConf: ['transferFromNHWCtoNCHW'] | ||
}; |
61 changes: 61 additions & 0 deletions
61
packages/paddlejs-backend-webgl/src/ops/shader/custom/pack_out.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/** | ||
* @file pack out | ||
*/ | ||
|
||
function mainFunc( | ||
{ origin }, | ||
{} | ||
) { | ||
const width_texture_origin = origin.width_texture; | ||
const height_texture_origin = origin.height_texture; | ||
return ` | ||
vec2 getOriginCoord(float x, float y) { | ||
if (x > float(${width_texture_origin})) { | ||
int num = int(x / float(${width_texture_origin})); | ||
x = mod(x, float(${width_texture_origin})); | ||
y = y + float(num); | ||
} | ||
return vec2(x, y); | ||
} | ||
float getClipedCoordRed(vec2 xy) { | ||
return TEXTURE2D( | ||
texture_origin, | ||
vec2(float(xy.x / float(${width_texture_origin})), float(xy.y / float(${height_texture_origin}))) | ||
).r; | ||
} | ||
// start函数 | ||
void main() { | ||
vec2 outCoord = vCoord.xy * _2d_shape_texture_out; | ||
vec4 out4; | ||
float x = floor(outCoord.x) * 4.0; | ||
float y = floor(outCoord.y) * 4.0 + 0.5; | ||
float x0 = x + 0.5; | ||
float x1 = x + 1.5; | ||
float x2 = x + 2.5; | ||
float x3 = x + 3.5; | ||
vec2 xy0 = getOriginCoord(x0, y); | ||
vec2 xy1 = getOriginCoord(x1, y); | ||
vec2 xy2 = getOriginCoord(x2, y); | ||
vec2 xy3 = getOriginCoord(x3, y); | ||
float r = getClipedCoordRed(xy0); | ||
float g = getClipedCoordRed(xy1); | ||
float b = getClipedCoordRed(xy2); | ||
float a = getClipedCoordRed(xy3); | ||
setPackedOutput(vec4(r, g, b, a)); | ||
} | ||
`; | ||
} | ||
export default { | ||
mainFunc, | ||
params: [], | ||
textureFuncConf: { | ||
origin: ['getValueFromTensorPosPacking', 'getValueFromTensorPos'] | ||
} | ||
}; |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
packages/paddlejs-backend-webgl/test/op/data/pack_out.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
{ | ||
"ops": [ | ||
{ | ||
"attrs": {}, | ||
"inputs": { | ||
"X": ["concat.tmp_0"] | ||
}, | ||
"outputs": { | ||
"Out": ["concat.tmp_2"] | ||
}, | ||
"type": "pack_out" | ||
} | ||
], | ||
"vars": [ | ||
{ | ||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], | ||
"name": "concat.tmp_0", | ||
"shape": [1, 1, 6, 2] | ||
}, | ||
{ | ||
"name": "concat.tmp_2", | ||
"shape": [1, 1, 2, 2] | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/** | ||
* @file nhwc 2 nchw & pack out | ||
*/ | ||
|
||
import env from '../env'; | ||
import { ModelOp, ModelVar } from '../commons/interface'; | ||
import { formatShape } from '../opFactory/utils'; | ||
import Transformer from './transformer'; | ||
|
||
const FINAL_PACK_OP_NAME = 'fetch_pack'; | ||
const FINAL_NCHW_OP_NAME = 'final_nchw'; | ||
|
||
export default class Fetch extends Transformer { | ||
constructor() { | ||
super('Fetch'); | ||
} | ||
|
||
transform(...args: any) { | ||
if (!env.get('webgl_pack_output')) { | ||
return; | ||
} | ||
const [ops, vars] = args; | ||
const fetchOp = ops.find(item => item.type === 'fetch'); | ||
const [inputName] = fetchOp.inputs.X; | ||
const fetchInputVar = vars.find(item => item.name === inputName); | ||
const [n, h, w, c] = formatShape(fetchInputVar.shape); | ||
|
||
// transform data from nhwc to nchw | ||
const nchwOp: ModelOp = { | ||
attrs: {}, | ||
inputs: { | ||
X: [inputName] | ||
}, | ||
outputs: { | ||
Y: [FINAL_NCHW_OP_NAME] | ||
}, | ||
type: 'nhwc_2_nchw' | ||
}; | ||
|
||
// pack out texture | ||
const packOutOp: ModelOp = { | ||
attrs: {}, | ||
inputs: { | ||
X: [FINAL_NCHW_OP_NAME] | ||
}, | ||
outputs: { | ||
Y: [FINAL_PACK_OP_NAME] | ||
}, | ||
type: 'pack_out' | ||
}; | ||
|
||
// make nchw op var | ||
const nchwVar = { | ||
name: FINAL_NCHW_OP_NAME, | ||
shape: [n, c, h, w], | ||
persistable: false | ||
}; | ||
|
||
const pack_width = c * w; | ||
const pack_height = Math.ceil(n * h / 4); | ||
// make pack op var | ||
const packOutVar = { | ||
name: FINAL_PACK_OP_NAME, | ||
shape: [1, 1, pack_height, pack_width], | ||
persistable: false | ||
}; | ||
|
||
const changed_fetch_name = `${inputName}_fetch`; | ||
fetchOp.inputs.X = [changed_fetch_name]; | ||
// save origin fetch op info | ||
const changedFetchVar: ModelVar = { | ||
name: changed_fetch_name, | ||
shape: fetchInputVar.shape, | ||
persistable: false | ||
}; | ||
|
||
ops.push(...[nchwOp, packOutOp]); | ||
vars.push(...[nchwVar, packOutVar, changedFetchVar]); | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters