-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathlinear.test.ts
40 lines (39 loc) · 1.04 KB
/
linear.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import * as sm from '@shumai/shumai'
import { describe, expect, it } from 'bun:test'
import { isShape } from './utils'
/**
 * Tests for `sm.module.linear`: output shapes for unbatched, single-row and
 * multi-row inputs, and gradient flow back to the input tensor.
 */
describe('linear', () => {
  const inDim = 64
  const outDim = 128

  /**
   * Builds a fresh linear(inDim, outDim) layer, feeds it a random tensor of
   * `inShape`, and checks the output has `outShape`.
   */
  const checkForwardShape = (inShape: number[], outShape: number[]) => {
    const x = sm.randn(inShape)
    const l = sm.module.linear(inDim, outDim)
    const y = l(x)
    expect(isShape(y, outShape)).toBe(true)
  }

  /**
   * Runs backward from sum(linear(x)) for a random `x` of `inShape` and checks
   * that `x` received a gradient with the same shape as `x` itself.
   */
  const checkInputGradient = (inShape: number[]) => {
    const x = sm.randn(inShape)
    x.requires_grad = true
    const l = sm.module.linear(inDim, outDim)
    const y = l(x)
    y.sum().backward()
    const grad = x.grad
    expect(!!grad).toBe(true)
    // Presence alone would not catch a backward pass that yields a wrongly
    // shaped gradient; d(loss)/dx must have x's shape.
    if (grad) expect(isShape(grad, inShape)).toBe(true)
  }

  it('basic construction', () => checkForwardShape([1, inDim], [1, outDim]))
  it('single sample', () => checkForwardShape([inDim], [outDim]))
  it('batch', () => checkForwardShape([37, inDim], [37, outDim]))
  it('gradient', () => checkInputGradient([2, inDim]))
  it('single sample gradient', () => checkInputGradient([inDim]))
})