Implementing two-dimensional bicubic interpolation for dlarray in MATLAB

This is a follow-up question on implementing two-dimensional bicubic interpolation in MATLAB. While designing a neural network, I want to compare the effect of using padding against using linear/nonlinear interpolation on the outputs of several convolutional layers in the hidden layers. Because bicubic interpolation is not among the Method options of resize2dLayer and dlresize, I am trying to port a two-dimensional bicubic interpolation algorithm into the model function, so that the resize operation with bicubic interpolation can take part in backpropagation.
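
For context, this is how the built-in resize is invoked; its Method option currently accepts only "nearest" and "linear" (a minimal sketch, assuming Image Processing Toolbox is available for dlresize):

    dlX = dlarray(rand(24, 24, 16, 8, 'single'), 'SSCB');
    dlY = dlresize(dlX, 'OutputSize', [28 28], 'Method', 'linear'); % no bicubic Method exists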

Experimental implementation

  • Implementation of the dlBicubicInterpolation function:

    function [output] = dlBicubicInterpolation(dlArray, newSize)
        Ndim = ndims(dlArray);
        if Ndim == 2
            output = BicubicInterpolation(dlArray, newSize);
            return;
        end
        if Ndim == 3
            output = dlarray(zeros([newSize size(dlArray, 3)]), dims(dlArray));
            for i = 1:size(dlArray, 3)
                output(:, :, i) = BicubicInterpolation(dlArray(:, :, i), newSize);
            end
            return;
        end
        if Ndim == 4
            output = dlarray(zeros([newSize size(dlArray, 3) size(dlArray, 4)]), dims(dlArray));
            for i = 1:size(dlArray, 3)
                for j = 1:size(dlArray, 4)
                    output(:, :, i, j) = BicubicInterpolation(dlArray(:, :, i, j), newSize);
                end
            end
            return;
        end
        error("Unsupported case!");
    end
    
  • Other helper functions used (a gradient sanity check follows after this list):

    function [output] = BicubicInterpolation(input, newSize)
        originSize = size(input);
        newSizeX = newSize(1);
        newSizeY = newSize(2);
        inputDims = dims(input);
        output = dlarray(zeros([newSizeX newSizeY]), inputDims(1:2)); % preallocate the output at its final size
        ratiox = originSize(1) / newSizeX;
        ratioy = originSize(2) / newSizeY;
    
        for y = 0:newSizeY - 1
            for x = 0:newSizeX - 1
                xMappingToOrigin = x * ratiox;
                yMappingToOrigin = y * ratioy;
                xMappingToOriginFloor = floor(xMappingToOrigin);
                yMappingToOriginFloor = floor(yMappingToOrigin);
                xMappingToOriginFrac = xMappingToOrigin - xMappingToOriginFloor;
                yMappingToOriginFrac = yMappingToOrigin - yMappingToOriginFloor;
                ndata = zeros(4, 4);
                for ndatay = -1:2
                    for ndatax = -1:2
                        ndata(ndatax + 2, ndatay + 2) = input( ...
                            clip(xMappingToOriginFloor + ndatax, 0, originSize(1) - 1) + 1, ...
                            clip(yMappingToOriginFloor + ndatay, 0, originSize(2) - 1) + 1);
                    end
                end
                output(x + 1, y + 1) = BicubicPolate(ndata, xMappingToOriginFrac, yMappingToOriginFrac);
            end
        end
    end
    
    function [output] = clip(input, lowerbound, upperbound)
        if (input > upperbound)
            output = upperbound;
            return;
        end
        if (input < lowerbound)
            output = lowerbound;
            return;
        end
        output = input;
    end
    
    function [output] = BicubicPolate(ndata, fracx, fracy)
        x1 = CubicPolate( ndata(1,1), ndata(2,1), ndata(3,1), ndata(4,1), fracx );
        x2 = CubicPolate( ndata(1,2), ndata(2,2), ndata(3,2), ndata(4,2), fracx );
        x3 = CubicPolate( ndata(1,3), ndata(2,3), ndata(3,3), ndata(4,3), fracx );
        x4 = CubicPolate( ndata(1,4), ndata(2,4), ndata(3,4), ndata(4,4), fracx );
    
        output = CubicPolate( x1, x2, x3, x4, fracy );
    end
    
    function [output] = CubicPolate(v0, v1, v2, v3, fracy )
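        % Cubic through four consecutive samples v0..v3: returns v1 at
        % fracy = 0 and v2 at fracy = 1 (a simple cubic, not the Keys
        % kernel that imresize uses for 'bicubic').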
        A = (v3-v2)-(v0-v1);
        B = (v0-v1)-A;
        C = v2-v0;
        D = v1;
        output =  D + fracy * (C + fracy * (B + fracy * A));
    end
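
A quick way to verify the goal, namely that gradients propagate through the resize, is to differentiate a scalar reduction of the output inside dlfeval (a minimal sketch; resizeCheck is just an illustrative helper name):

    function [loss, grad] = resizeCheck(dlX)
        % Forward pass through the custom resize, then a scalar loss.
        dlY = dlBicubicInterpolation(dlX, [28 28]);
        loss = sum(dlY, 'all');
        % Gradient with respect to the input; an error here would mean
        % some operation inside the resize breaks dlarray tracing.
        grad = dlgradient(loss, dlX);
    end

    % Usage (dlgradient must be evaluated inside dlfeval):
    dlX = dlarray(rand(14, 14, 'single'), 'SS');
    [loss, grad] = dlfeval(@resizeCheck, dlX);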
    

Full testing code

Referring to the example on the Train Network Using Model Function documentation page, the line dlY = dlconv(dlX,weights,bias,'Padding','same'); in the first convolutional layer was changed to dlY = dlconv(dlX,weights,bias);, and the dlBicubicInterpolation function is then used to perform the resize operation, dlY = dlBicubicInterpolation(dlY, [28 28]);, for testing, as shown in the snippet below.
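
The relevant change, taken verbatim from the model function below:

    % Original example:
    % dlY = dlconv(dlX,weights,bias,'Padding','same');

    % Modified version under test: unpadded convolution, then bicubic resize
    dlY = dlconv(dlX,weights,bias);
    dlY = dlBicubicInterpolation(dlY, [28 28]);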

%% Load Training Data

[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsYTrain = arrayDatastore(YTrain);
dsAnglesTrain = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsYTrain,dsAnglesTrain);

classNames = categories(YTrain);
numClasses = numel(classNames);
numResponses = size(anglesTrain,2);
numObservations = numel(YTrain);
%% 
% View some images from the training data.

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)
%% Define Deep Learning Model
% Define the following network that predicts both labels and angles of rotation.

%% 
% Define and Initialize Model Parameters and State

filterSize = [5 5];
numChannels = 1;
numFilters = 16;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv1.Bias = initializeZeros([numFilters 1]);
%% 
% Initialize the parameters and state for the first batch normalization layer.

parameters.batchnorm1.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm1.Scale = initializeOnes([numFilters 1]);
state.batchnorm1.TrainedMean = zeros(numFilters,1,'single');
state.batchnorm1.TrainedVariance = ones(numFilters,1,'single');
%% 
% Initialize the parameters for the second convolutional layer.

filterSize = [3 3];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv2.Bias = initializeZeros([numFilters 1]);
%% 
% Initialize the parameters and state for the second batch normalization layer.

parameters.batchnorm2.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm2.Scale = initializeOnes([numFilters 1]);
state.batchnorm2.TrainedMean = zeros(numFilters,1,'single');
state.batchnorm2.TrainedVariance = ones(numFilters,1,'single');
%% 
% Initialize the parameters for the third convolutional layer.

filterSize = [3 3];
numChannels = 32;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv3.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv3.Bias = initializeZeros([numFilters 1]);
%% 
% Initialize the parameters and state for the third batch normalization layer.

parameters.batchnorm3.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm3.Scale = initializeOnes([numFilters 1]);
state.batchnorm3.TrainedMean = zeros(numFilters,1,'single');
state.batchnorm3.TrainedVariance = ones(numFilters,1,'single');
%% 
% Initialize the parameters for the convolutional layer in the skip connection.

filterSize = [1 1];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.convSkip.Weights = initializeGlorot(sz,numOut,numIn);
parameters.convSkip.Bias = initializeZeros([numFilters 1]);
%% 
% Initialize the parameters and state for the batch normalization layer in the 
% skip connection.

parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]);
parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]);
state.batchnormSkip.TrainedMean = zeros([numFilters 1],'single');
state.batchnormSkip.TrainedVariance = ones([numFilters 1],'single');
%% 
% Initialize the parameters for the fully connected layer corresponding to the 
% classification output.

sz = [numClasses 6272];
numOut = numClasses;
numIn = 6272;
parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc1.Bias = initializeZeros([numClasses 1]);
%% 
% Initialize the parameters for the fully connected layer corresponding to the 
% regression output.

sz = [numResponses 6272];
numOut = numResponses;
numIn = 6272;
parameters.fc2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc2.Bias = initializeZeros([numResponses 1]);
%% 
% View the struct of the parameters.

parameters
%% 
% View the parameters for the "conv1" operation.

parameters.conv1
%% 
% View the struct of the state.

state
%% 
state.batchnorm1

numEpochs = 20;
miniBatchSize = 128;
%% 

plots = "training-progress";
%% Train Model

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB','',''});
%% 
% Initialize parameters for Adam.

trailingAvg = [];
trailingAvgSq = [];
%% 
% Initialize the training progress plot.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end
%% 
% Train the model. 

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    shuffle(mbq)
    
    % Loop over mini-batches
    while hasdata(mbq)
    
        iteration = iteration + 1;
        
        [dlX,dlY1,dlY2] = next(mbq);
              
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,state,loss] = dlfeval(@modelGradients, parameters, dlX, dlY1, dlY2, state);
        
        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
%% Test Model
[XTest,YTest,anglesTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsYTest = arrayDatastore(YTest);
dsAnglesTest = arrayDatastore(anglesTest);

dsTest = combine(dsXTest,dsYTest,dsAnglesTest);

mbqTest = minibatchqueue(dsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB','',''});
%% 

doTraining = false;

classesPredictions = [];
anglesPredictions = [];
classCorr = [];
angleDiff = [];

% Loop over mini-batches.
while hasdata(mbqTest)
    
    % Read mini-batch of data.
    [dlXTest,dlY1Test,dlY2Test] = next(mbqTest);
    
    % Make predictions using the model function.
    [dlY1Pred,dlY2Pred] = model(parameters,dlXTest,doTraining,state);
    
    % Determine predicted classes.
    Y1PredBatch = onehotdecode(dlY1Pred,classNames,1);
    classesPredictions = [classesPredictions Y1PredBatch];
    
    % Determine predicted angles
    Y2PredBatch = extractdata(dlY2Pred);
    anglesPredictions = [anglesPredictions Y2PredBatch];
    
    % Compare predicted and true classes
    Y1Test = onehotdecode(dlY1Test,classNames,1);
    classCorr = [classCorr Y1PredBatch == Y1Test];
    
    % Compare predicted and true angles
    angleDiffBatch = Y2PredBatch - dlY2Test;
    angleDiff = [angleDiff extractdata(gather(angleDiffBatch))];
    
end
%% 

accuracy = mean(classCorr)
%% 

angleRMSE = sqrt(mean(angleDiff.^2))
%% 

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = anglesPredictions(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = anglesTest(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(classesPredictions(idx(i)));
    title("Label: " + label)
end
%% Model Function

function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
dlY = dlconv(dlX,weights,bias);
dlY = dlBicubicInterpolation(dlY, [28 28]);

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution, batch normalization (Skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,weights,bias,'Stride',2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same');

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = dlYSkip + dlY;
dlY = relu(dlY);

weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
dlY1 = fullyconnect(dlY,weights,bias);
dlY1 = softmax(dlY1);

weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
dlY2 = fullyconnect(dlY,weights,bias);

end
%% Model Gradients Function

function [gradients,state,loss] = modelGradients(parameters,dlX,T1,T2,state)

doTraining = true;
[dlY1,dlY2,state] = model(parameters,dlX,doTraining,state);

lossLabels = crossentropy(dlY1,T1);
lossAngles = mse(dlY2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,parameters);

end
%% Mini-Batch Preprocessing Function

function [X,Y,angle] = preprocessMiniBatch(XCell,YCell,angleCell)
    X = cat(4,XCell{:});
    Y = cat(2,YCell{:});
    angle = cat(2,angleCell{:});
    Y = onehotencode(Y,1);
        
end

%% 
function [output] = dlBicubicInterpolation(dlArray, newSize)
    Ndim = ndims(dlArray);
    if Ndim == 2
        output = BicubicInterpolation(dlArray, newSize);
        return;
    end
    if Ndim == 3
        output = dlarray(zeros([newSize size(dlArray, 3)]), dims(dlArray));
        for i = 1:size(dlArray, 3)
            output(:, :, i) = BicubicInterpolation(dlArray(:, :, i), newSize);
        end
        return;
    end
    if Ndim == 4
        output = dlarray(zeros([newSize size(dlArray, 3) size(dlArray, 4)]), dims(dlArray));
        for i = 1:size(dlArray, 3)
            for j = 1:size(dlArray, 4)
                output(:, :, i, j) = BicubicInterpolation(dlArray(:, :, i, j), newSize);
            end
        end
        return;
    end
    error("Unsupported case!");
end

function [output] = BicubicInterpolation(input, newSize)
    originSize = size(input);
    newSizeX = newSize(1);
    newSizeY = newSize(2);
    inputDims = dims(input);
    output = dlarray(zeros([newSizeX newSizeY]), inputDims(1:2)); % preallocate the output at its final size
    ratiox = originSize(1) / newSizeX;
    ratioy = originSize(2) / newSizeY;

    for y = 0:newSizeY - 1
        for x = 0:newSizeX - 1
            xMappingToOrigin = x * ratiox;
            yMappingToOrigin = y * ratioy;
            xMappingToOriginFloor = floor(xMappingToOrigin);
            yMappingToOriginFloor = floor(yMappingToOrigin);
            xMappingToOriginFrac = xMappingToOrigin - xMappingToOriginFloor;
            yMappingToOriginFrac = yMappingToOrigin - yMappingToOriginFloor;
            ndata = zeros(4, 4);
            for ndatay = -1:2
                for ndatax = -1:2
                    ndata(ndatax + 2, ndatay + 2) = input( ...
                        clip(xMappingToOriginFloor + ndatax, 0, originSize(1) - 1) + 1, ...
                        clip(yMappingToOriginFloor + ndatay, 0, originSize(2) - 1) + 1);
                end
            end
            output(x + 1, y + 1) = BicubicPolate(ndata, xMappingToOriginFrac, yMappingToOriginFrac);
        end
    end
end

function [output] = clip(input, lowerbound, upperbound)
    if (input > upperbound)
        output = upperbound;
        return;
    end
    if (input < lowerbound)
        output = lowerbound;
        return;
    end
    output = input;
end

function [output] = BicubicPolate(ndata, fracx, fracy)
    x1 = CubicPolate( ndata(1,1), ndata(2,1), ndata(3,1), ndata(4,1), fracx );
    x2 = CubicPolate( ndata(1,2), ndata(2,2), ndata(3,2), ndata(4,2), fracx );
    x3 = CubicPolate( ndata(1,3), ndata(2,3), ndata(3,3), ndata(4,3), fracx );
    x4 = CubicPolate( ndata(1,4), ndata(2,4), ndata(3,4), ndata(4,4), fracx );

    output = CubicPolate( x1, x2, x3, x4, fracy );
end

function [output] = CubicPolate(v0, v1, v2, v3, fracy )
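    % Cubic through four consecutive samples v0..v3: returns v1 at
    % fracy = 0 and v2 at fracy = 1 (a simple cubic, not the Keys
    % kernel that imresize uses for 'bicubic').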
    A = (v3-v2)-(v0-v1);
    B = (v0-v1)-A;
    C = v2-v0;
    D = v1;
    output =  D + fracy * (C + fracy * (B + fracy * A));
end

Test platform information

MATLAB version: '9.10.0.1684407 (R2021a) Update 3'

Any suggestions are welcome.

Summary information:

  • What question is this a follow-up to?

    Implementing two-dimensional bicubic interpolation in MATLAB.

  • What changes have been made to the code since the last question?

    I am trying to port the two-dimensional bicubic interpolation algorithm into the model function in this post.

  • Why is a new review being requested?

    If there are any possible improvements, please let me know.
