129 lines
5.6 KiB
Plaintext
129 lines
5.6 KiB
Plaintext
|
"""
|
|||
|
https://stepik.org/lesson/701989/step/11?unit=702090
|
|||
|
|
|||
|
В нейронных сетях использую операцию под названием Max Pooling.
|
|||
|
Суть ее состоит в сканировании прямоугольной таблицы чисел (матрицы)
|
|||
|
окном определенного размера (обычно, 2x2 элемента) и выбора наибольшего значения в пределах этого окна.
|
|||
|
Или, если окна выходят за пределы матрицы, то они пропускаются (игнорируются).
|
|||
|
|
|||
|
|
|||
|
Мы повторим эту процедуру. Для этого в программе нужно объявить класс с именем MaxPooling, объекты которого создаются командой:
|
|||
|
|
|||
|
mp = MaxPooling(step=(2, 2), size=(2,2))
|
|||
|
где step - шаг смещения окна по горизонтали и вертикали; size - размер окна по горизонтали и вертикали.
|
|||
|
|
|||
|
Параметры step и size по умолчанию должны принимать кортеж со значениями (2, 2).
|
|||
|
|
|||
|
Для выполнения операции Max Pooling используется команда:
|
|||
|
|
|||
|
res = mp(matrix)
|
|||
|
где matrix - прямоугольная таблица чисел; res - ссылка на результат обработки таблицы matrix (должна создаваться новая таблица чисел.
|
|||
|
|
|||
|
Прямоугольную таблицу чисел следует описывать вложенными списками. Если при сканировании таблицы часть окна выходит за ее пределы, то эти данные отбрасывать (не учитывать все окно).
|
|||
|
|
|||
|
Если matrix не является прямоугольной таблицей или содержит хотя бы одно не числовое значение, то должно генерироваться исключение командой:
|
|||
|
|
|||
|
raise ValueError("Неверный формат для первого параметра matrix.")
|
|||
|
Пример использования класса (эти строчки в программе писать не нужно):
|
|||
|
|
|||
|
mp = MaxPooling(step=(2, 2), size=(2,2))
|
|||
|
res = mp([[1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6], [5, 4, 3, 2]]) # [[6, 8], [9, 7]]
|
|||
|
Результатом будет таблица чисел:
|
|||
|
|
|||
|
6 8
|
|||
|
9 7
|
|||
|
|
|||
|
P.S. В программе достаточно объявить только класс. Выводить на экран ничего не нужно.
|
|||
|
|
|||
|
"""
|
|||
|
|
|||
|
|
|||
|
class MaxPooling:
|
|||
|
def __init__(self, step=(2, 2), size=(2, 2)):
|
|||
|
self.step = step
|
|||
|
self.size = size
|
|||
|
|
|||
|
def validate_matrix(self, matrix):
|
|||
|
if (
|
|||
|
any(not isinstance(row, list) for row in matrix)
|
|||
|
or any(not isinstance(num, (int, float)) for row in matrix for num in row)
|
|||
|
or len(set(len(row) for row in matrix)) != 1
|
|||
|
):
|
|||
|
raise ValueError("Неверный формат для первого параметра matrix.")
|
|||
|
|
|||
|
def __call__(self, matrix):
|
|||
|
self.validate_matrix(matrix)
|
|||
|
N, M = len(matrix), len(matrix[0])
|
|||
|
step_x, step_y, size_x, size_y = (*self.step, *self.size)
|
|||
|
return [
|
|||
|
[
|
|||
|
max(
|
|||
|
matrix[x][y]
|
|||
|
for x in range(i, i + size_x)
|
|||
|
for y in range(j, j + size_y)
|
|||
|
)
|
|||
|
for j in range(0, M - size_y + 1, step_y)
|
|||
|
]
|
|||
|
for i in range(0, N - size_x + 1, step_x)
|
|||
|
]
|
|||
|
|
|||
|
|
|||
|
def tests():
|
|||
|
mp = MaxPooling(step=(2, 2), size=(2, 2))
|
|||
|
m1 = [[1, 10, 10], [5, 10, 0], [0, 1, 2]]
|
|||
|
m2 = [[1, 10, 10, 12], [5, 10, 0, -5], [0, 1, 2, 300], [40, -100, 0, 54.5]]
|
|||
|
res1 = mp(m1)
|
|||
|
res2 = mp(m2)
|
|||
|
|
|||
|
assert res1 == [[10]], "неверный результат операции MaxPooling"
|
|||
|
assert res2 == [[10, 12], [40, 300]], "неверный результат операции MaxPooling"
|
|||
|
|
|||
|
mp = MaxPooling(step=(3, 3), size=(2, 2))
|
|||
|
m3 = [[1, 12, 14, 12], [5, 10, 0, -5], [0, 1, 2, 300], [40, -100, 0, 54.5]]
|
|||
|
res3 = mp(m3)
|
|||
|
assert res3 == [
|
|||
|
[12]
|
|||
|
], "неверный результат операции при MaxPooling(step=(3, 3), size=(2,2))"
|
|||
|
|
|||
|
try:
|
|||
|
res = mp([[1, 2], [3, 4, 5]])
|
|||
|
except ValueError:
|
|||
|
assert True
|
|||
|
else:
|
|||
|
assert (
|
|||
|
False
|
|||
|
), "некорректо отработала проверка (или она отсутствует) на не прямоугольную матрицу"
|
|||
|
|
|||
|
try:
|
|||
|
res = mp([[1, 2], [3, "4"]])
|
|||
|
except ValueError:
|
|||
|
assert True
|
|||
|
else:
|
|||
|
assert (
|
|||
|
False
|
|||
|
), "некорректо отработала проверка (или она отсутствует) на не числовые значения в матрице"
|
|||
|
|
|||
|
mp = MaxPooling(step=(1, 1), size=(5, 5))
|
|||
|
res = mp(
|
|||
|
[
|
|||
|
[5, 0, 88, 2, 7, 65],
|
|||
|
[1, 33, 7, 45, 0, 1],
|
|||
|
[54, 8, 2, 38, 22, 7],
|
|||
|
[73, 23, 6, 1, 15, 0],
|
|||
|
[4, 12, 9, 1, 76, 6],
|
|||
|
[0, 15, 10, 8, 11, 78],
|
|||
|
]
|
|||
|
) # [[88, 88], [76, 78]]
|
|||
|
|
|||
|
assert res == [
|
|||
|
[88, 88],
|
|||
|
[76, 78],
|
|||
|
], "неверный результат операции MaxPooling(step=(1, 1), size=(5, 5))"
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
import doctest
|
|||
|
|
|||
|
doctest.testmod()
|
|||
|
tests()
|