账号密码登录
微信安全登录
微信扫描二维码登录

登录后绑定QQ、微信即可实现信息互通

手机验证码登录
找回密码返回
邮箱找回 手机找回
注册账号返回
其他登录方式
分享
  • 收藏
    X
    tensorflow实现mnist手写数字识别报错?
    31
    0

    代码:
    import tensorflow as tf
    import numpy as np
    tf.enable_eager_execution()

    class DataLoader():

    def __init__(self):
        mnist = tf.keras.datasets.mnist.load_data(path = 'mnist.npz')
    
        self.train_data = mnist[0][0]
        self.train_data = np.reshape(self.train_data,(self.train_data.shape[0],28*28))
        self.train_labels = mnist[0][1]
        self.eval_data = mnist[1][0]
        self.train_data = np.reshape(self.train_data,(self.train_data.shape[0],28*28))
        self.eval_labels = mnist[1][1]
    
    def get_batch(self, batch_size):
        indexs = np.random.randint(0,self.train_data.shape[0],batch_size)
        return self.train_data[indexs, :], self.train_labels[indexs]
    

    '''
    class MLP(tf.keras.Modle):
    '''
    class MLP(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(units=100, activation= tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10,activation =None)
    
    def call(self, inputs):
        x = self.dense1(inputs)
        y = self.dense2(x)
        return y
    
    def predict(self, inputs):
        logits = self(inputs)
        return tf.argmax(logits, axis=-1)
    
    
    
    

    num_batches = 1000
    batch_size = 50
    learning_rate = 0.001

    model = MLP()
    data_loader = DataLoader()
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

    for batch_index in range(num_batches):

    X , y = data_loader.get_batch(batch_size)
    print(np.shape(X))
    with tf.GradientTape() as tape:
        X = tf.convert_to_tensor(X, dtype = tf.int64, name = 'X')
        print(X)
        y_logit_pred = model(X)
        loss = tf.losses.sparse_softmax_cross_entropy(labels = y, logits = y_logit_pred)
        print('batch %d: loss %f' % (batch_index, loss.numpy()))
    
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars = zip(grads, model.variables))
    

    num_eval_samples = np.shape(data_loader.eval_labels)[0]
    y_pred = model.predict(data_loader.eval_data).numpy()
    print("test accuracy: %f" % (sum(y_pred == data_loader.eval_labels) / num_eval_samples))

    错误信息:

    /home/kalarea/.conda/envs/py35/bin/python /home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py
    /home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from float to np.floating is deprecated. In future, it will be treated as np.float64 == np.dtype(float).type.
    from ._conv import register_converters as _register_converters
    (50, 784)
    2018-10-14 18:28:18.977966: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
    tf.Tensor(
    [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]], shape=(50, 784), dtype=int64)
    Traceback (most recent call last):
    File "/home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py", line 55, in <module>

    y_logit_pred = model(X)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 769, in call

    outputs = self.call(inputs, *args, **kwargs)

    File "/home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py", line 30, in call

    x = self.dense1(inputs)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 759, in call

    self.build(input_shapes)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/layers/core.py", line 921, in build

    trainable=True)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 586, in add_weight

    aggregation=aggregation)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/checkpointable/base.py", line 591, in _add_variable_with_custom_getter

    **kwargs_for_getter)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1986, in make_variable

    aggregation=aggregation)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 145, in call

    return cls._variable_call(*args, **kwargs)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 141, in _variable_call

    aggregation=aggregation)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 120, in <lambda>

    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 2434, in default_variable_creator

    import_scope=import_scope)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 147, in call

    return super(VariableMetaclass, cls).__call__(*args, **kwargs)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 297, in init

    constraint=constraint)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 420, in _init_from_args

    initial_value = initial_value()

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1970, in <lambda>

    shape, dtype=dtype, partition_info=partition_info)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/init_ops.py", line 483, in call

    shape, -limit, limit, dtype, seed=self.seed)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/random_ops.py", line 240, in random_uniform

    shape, minval, maxval, seed=seed1, seed2=seed2, name=name)

    File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gen_random_ops.py", line 848, in random_uniform_int

    _six.raise_from(_core._status_to_exception(e.code, message), None)

    File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Need minval < maxval, got 0 >= 0 [Op:RandomUniformInt] name: mlp/dense/kernel/random_uniform/

    Process finished with exit code 1

    1
    打赏
    收藏
    点击回答
        全部回答
    • 0
    • 一骑轻尘 普通会员 1楼
      502 Bad Gateway

      502 Bad Gateway


      nginx
    更多回答
    扫一扫访问手机版
    • 回到顶部
    • 回到顶部