Arhn - архитектура программирования

TensorFlow Keras MaxPool2D ломает LSTM с потерей CTC?

Я пытаюсь связать слой CNN с двумя слоями LSTM и ctc_batch_cost для потерь, но сталкиваюсь с некоторыми проблемами. Моя модель должна работать с изображениями в градациях серого.

Во время отладки я понял, что если я использую только слой CNN, который сохраняет размер вывода равным размеру ввода + LSTM и CTC, модель может обучаться:

# === Without MaxPool2D ===
inp = Input(name='inp', shape=(128, 32, 1))

cnn = Conv2D(name='conv', filters=1, kernel_size=3, strides=1, padding='same')(inp)

# Go from Bx128x32x1 to Bx128x32 (B x TimeSteps x Features)
rnn_inp = Reshape((128, 32))(maxp)

blstm = Bidirectional(LSTM(256, return_sequences=True), name='blstm1')(rnn_inp)
blstm = Bidirectional(LSTM(256, return_sequences=True), name='blstm2')(blstm)

# Softmax.
dense = TimeDistributed(Dense(80, name='dense'), name='timedDense')(blstm)
rnn_outp = Activation('softmax', name='softmax')(dense)

# Model compiles, calling fit works!

Но когда я добавляю слой MaxPool2D, размеры которого уменьшаются вдвое, я получаю сообщение об ошибке sequence_length(0) <= 64, похожее на представленное здесь.

# === With MaxPool2D ===
inp = Input(name='inp', shape=(128, 32, 1))

cnn = Conv2D(name='conv', filters=1, kernel_size=3, strides=1, padding='same')(inp)
maxp = MaxPool2D(name='maxp', pool_size=2, strides=2, padding='valid')(cnn) # -> 64x16x1

# Go from Bx64x16x1 to Bx64x16 (B x TimeSteps x Features)
rnn_inp = Reshape((64, 16))(maxp)

blstm = Bidirectional(LSTM(256, return_sequences=True), name='blstm1')(rnn_inp)
blstm = Bidirectional(LSTM(256, return_sequences=True), name='blstm2')(blstm)

# Softmax.
dense = TimeDistributed(Dense(80, name='dense'), name='timedDense')(blstm)
rnn_outp = Activation('softmax', name='softmax')(dense)

# Model compiles, but calling fit crashes with:
# InvalidArgumentError: sequence_length(0) <= 64
#    [[{{node ctc_loss_1/CTCLoss}}]]

Ответы:


1

После трехдневной борьбы с этой проблемой я разместил вышеуказанный вопрос здесь, на StackOverflow. Примерно через 2 часа после публикации вопросов я наконец понял это.

TL; DR Решение:

Если вы используете ctc_batch_cost:

Убедитесь, что вы передаете длины (количество временных шагов) последовательностей, входящих в ваши RNN, в качестве входных данных для аргумента input_length.

Если вы используете ctc_loss:

Убедитесь, что вы передаете длины (количество временных шагов) последовательностей, входящих в ваши RNN, в качестве входных данных для аргумента logit_length.

Решение:

Решение кроется в документации, которая, будучи относительно скудной, может быть загадочной для новичка в области машинного обучения, такого как я.

документация TensorFlow для ctc_batch_cost гласит:

tf.keras.backend.ctc_batch_cost(
    y_true, y_pred, input_length, label_length
)

...

Тензор input_length (выборки, 1), содержащий длину последовательности для каждого элемента пакета в y_pred.

...

input_length соответствует logit_length из ctc_loss документации TensorFlow функции:

tf.nn.ctc_loss(
    labels, logits, label_length, logit_length, logits_time_major=True, unique=None,
    blank_index=None, name=None
)

...

logit_length тензор формы [batch_size] Длина входной последовательности в логитах.

...

Вот где он щелкнул по слову logit. Таким образом, аргумент для input_length или logit_length должен быть тензором/контейнером (в моем случае массивом numpy) длин (т.е. количества временных шагов) последовательностей ввод RNN (в моем случае LSTM) в качестве входных данных.

Первоначально я совершал ошибку, считая требуемую длину шириной изображений в градациях серого, которые служат входными данными для всей сети (CNN + MaxPool2D + RNN), но поскольку слой MaxPool2D создает тензор разных размеров для входных данных RNN , функция потери ctc дает сбой.

Теперь фит работает без сбоев.

25.05.2020
Новые материалы

Коллекции публикаций по глубокому обучению
Последние пару месяцев я создавал коллекции последних академических публикаций по различным подполям глубокого обучения в моем блоге https://amundtveit.com - эта публикация дает обзор 25..

Представляем: Pepita
Фреймворк JavaScript с открытым исходным кодом Я знаю, что недостатка в фреймворках JavaScript нет. Но я просто не мог остановиться. Я хотел написать что-то сам, со своими собственными..

Советы по коду Laravel #2
1-) Найти // You can specify the columns you need // in when you use the find method on a model User::find(‘id’, [‘email’,’name’]); // You can increment or decrement // a field in..

Работа с временными рядами спутниковых изображений, часть 3 (аналитика данных)
Анализ временных рядов спутниковых изображений для данных наблюдений за большой Землей (arXiv) Автор: Рольф Симоэс , Жильберто Камара , Жильберто Кейрос , Фелипе Соуза , Педро Р. Андраде ,..

3 способа решить квадратное уравнение (3-й мой любимый) -
1. Методом факторизации — 2. Используя квадратичную формулу — 3. Заполнив квадрат — Давайте поймем это, решив это простое уравнение: Мы пытаемся сделать LHS,..

Создание VR-миров с A-Frame
Виртуальная реальность (и дополненная реальность) стали главными модными терминами в образовательных технологиях. С недорогими VR-гарнитурами, такими как Google Cardboard , и использованием..

Демистификация рекурсии
КОДЕКС Демистификация рекурсии Упрощенная концепция ошеломляющей О чем весь этот шум? Рекурсия, кажется, единственная тема, от которой у каждого начинающего студента-информатика..