129 lines
5.6 KiB
Plaintext
129 lines
5.6 KiB
Plaintext
"""
|
||
https://stepik.org/lesson/701989/step/12?unit=702090
|
||
|
||
В нейронных сетях использую операцию под названием Max Pooling.
|
||
Суть ее состоит в сканировании прямоугольной таблицы чисел (матрицы)
|
||
окном определенного размера (обычно, 2x2 элемента) и выбора наибольшего значения в пределах этого окна.
|
||
Или, если окна выходят за пределы матрицы, то они пропускаются (игнорируются).
|
||
|
||
|
||
Мы повторим эту процедуру. Для этого в программе нужно объявить класс с именем MaxPooling, объекты которого создаются командой:
|
||
|
||
mp = MaxPooling(step=(2, 2), size=(2,2))
|
||
где step - шаг смещения окна по горизонтали и вертикали; size - размер окна по горизонтали и вертикали.
|
||
|
||
Параметры step и size по умолчанию должны принимать кортеж со значениями (2, 2).
|
||
|
||
Для выполнения операции Max Pooling используется команда:
|
||
|
||
res = mp(matrix)
|
||
где matrix - прямоугольная таблица чисел; res - ссылка на результат обработки таблицы matrix (должна создаваться новая таблица чисел.
|
||
|
||
Прямоугольную таблицу чисел следует описывать вложенными списками. Если при сканировании таблицы часть окна выходит за ее пределы, то эти данные отбрасывать (не учитывать все окно).
|
||
|
||
Если matrix не является прямоугольной таблицей или содержит хотя бы одно не числовое значение, то должно генерироваться исключение командой:
|
||
|
||
raise ValueError("Неверный формат для первого параметра matrix.")
|
||
Пример использования класса (эти строчки в программе писать не нужно):
|
||
|
||
mp = MaxPooling(step=(2, 2), size=(2,2))
|
||
res = mp([[1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6], [5, 4, 3, 2]]) # [[6, 8], [9, 7]]
|
||
Результатом будет таблица чисел:
|
||
|
||
6 8
|
||
9 7
|
||
|
||
P.S. В программе достаточно объявить только класс. Выводить на экран ничего не нужно.
|
||
|
||
"""
|
||
|
||
|
||
class MaxPooling:
|
||
def __init__(self, step=(2, 2), size=(2, 2)):
|
||
self.step = step
|
||
self.size = size
|
||
|
||
def validate_matrix(self, matrix):
|
||
if (
|
||
any(not isinstance(row, list) for row in matrix)
|
||
or any(not isinstance(num, (int, float)) for row in matrix for num in row)
|
||
or len(set(len(row) for row in matrix)) != 1
|
||
):
|
||
raise ValueError("Неверный формат для первого параметра matrix.")
|
||
|
||
def __call__(self, matrix):
|
||
self.validate_matrix(matrix)
|
||
N, M = len(matrix), len(matrix[0])
|
||
step_x, step_y, size_x, size_y = (*self.step, *self.size)
|
||
return [
|
||
[
|
||
max(
|
||
matrix[x][y]
|
||
for x in range(i, i + size_x)
|
||
for y in range(j, j + size_y)
|
||
)
|
||
for j in range(0, M - size_y + 1, step_y)
|
||
]
|
||
for i in range(0, N - size_x + 1, step_x)
|
||
]
|
||
|
||
|
||
def tests():
|
||
mp = MaxPooling(step=(2, 2), size=(2, 2))
|
||
m1 = [[1, 10, 10], [5, 10, 0], [0, 1, 2]]
|
||
m2 = [[1, 10, 10, 12], [5, 10, 0, -5], [0, 1, 2, 300], [40, -100, 0, 54.5]]
|
||
res1 = mp(m1)
|
||
res2 = mp(m2)
|
||
|
||
assert res1 == [[10]], "неверный результат операции MaxPooling"
|
||
assert res2 == [[10, 12], [40, 300]], "неверный результат операции MaxPooling"
|
||
|
||
mp = MaxPooling(step=(3, 3), size=(2, 2))
|
||
m3 = [[1, 12, 14, 12], [5, 10, 0, -5], [0, 1, 2, 300], [40, -100, 0, 54.5]]
|
||
res3 = mp(m3)
|
||
assert res3 == [
|
||
[12]
|
||
], "неверный результат операции при MaxPooling(step=(3, 3), size=(2,2))"
|
||
|
||
try:
|
||
res = mp([[1, 2], [3, 4, 5]])
|
||
except ValueError:
|
||
assert True
|
||
else:
|
||
assert (
|
||
False
|
||
), "некорректо отработала проверка (или она отсутствует) на не прямоугольную матрицу"
|
||
|
||
try:
|
||
res = mp([[1, 2], [3, "4"]])
|
||
except ValueError:
|
||
assert True
|
||
else:
|
||
assert (
|
||
False
|
||
), "некорректо отработала проверка (или она отсутствует) на не числовые значения в матрице"
|
||
|
||
mp = MaxPooling(step=(1, 1), size=(5, 5))
|
||
res = mp(
|
||
[
|
||
[5, 0, 88, 2, 7, 65],
|
||
[1, 33, 7, 45, 0, 1],
|
||
[54, 8, 2, 38, 22, 7],
|
||
[73, 23, 6, 1, 15, 0],
|
||
[4, 12, 9, 1, 76, 6],
|
||
[0, 15, 10, 8, 11, 78],
|
||
]
|
||
) # [[88, 88], [76, 78]]
|
||
|
||
assert res == [
|
||
[88, 88],
|
||
[76, 78],
|
||
], "неверный результат операции MaxPooling(step=(1, 1), size=(5, 5))"
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import doctest
|
||
|
||
doctest.testmod()
|
||
tests()
|