作者:蛮小蛮将军_415 | 来源:互联网 | 2022-12-02 17:03
出于学习目的,我正在使用Tensorflow.js,并且在尝试将该fit
方法与批处理数据集(10 x 10)一起使用时遇到错误,以了解批处理培训的过程.
我有一些想要分类的图像600x600x3(2个输出,1或0)
这是我的训练循环:
const batches = await loadDataset()
for (let i = 0; i
以下是我定义数据集的方法
const chunks = chunk(examples, BATCH_SIZE)
const batches = chunks.map(
batch => {
const ys = tf.tensor1d(batch.map(e => e.y), 'int32')
const xs = batch
.map(e => imageToInput(e.x, 3))
.reduce((p, c) => p ? p.concat(c) : c)
return { size: batch.length, xs , ys }
}
)
这是模型:
const model = tf.sequential()
model.add(tf.layers.conv2d({
inputShape: [600, 600, 3],
kernelSize: 60,
filters: 50,
strides: 20,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
}))
model.add(tf.layers.maxPooling2d({
poolSize: [20, 20],
strides: [20, 20]
}))
model.add(tf.layers.conv2d({
kernelSize: 5,
filters: 100,
strides: 20,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
}))
model.add(tf.layers.maxPooling2d({
poolSize: [20, 20],
strides: [20, 20]
}))
model.add(tf.layers.flatten())
model.add(tf.layers.dense({
units: 2,
kernelInitializer: 'VarianceScaling',
activation: 'softmax'
}))
我在for循环中的第一次迭代中遇到错误.fit
,如下所示:
Error: new shape and old shape must have the same number of elements.
at Object.assert (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/util.js:36:15)
at reshape_ (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/ops/array_ops.js:271:10)
at Object.reshape (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/ops/operation.js:23:29)
at Tensor.reshape (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/tensor.js:273:26)
at Object.derB [as $b] (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/ops/binary_ops.js:32:24)
at _loop_1 (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/tape.js:90:47)
at Object.backpropagateGradients (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/tape.js:108:9)
at /Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/engine.js:334:20
at /Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/engine.js:91:22
at Engine.scopedRun (/Users/person/nn/node_modules/@tensorflow/tfjs-core/dist/engine.js:101:23)
我不知道从中可以理解什么,并且没有发现任何关于该特定错误的文档或帮助,任何想法?