import math
import logging

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from . import deform_conv_cuda

logger = logging.getLogger('base')


class DeformConvFunction(Function):
    """Autograd ``Function`` that dispatches deformable convolution to ``deform_conv_cuda``.

    Only CUDA tensors are handled; CPU tensors raise ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
                deformable_groups=1, im2col_step=64):
        """Run the forward deformable-convolution kernel.

        Args:
            ctx: autograd context; receives the hyper-parameters, the saved
                tensors and two scratch buffers that ``backward`` reuses.
            input: 4D tensor; dimension 0 is the batch.
            offset: sampling offsets, handed to the CUDA kernel unchanged
                (its layout is defined by the extension and not checked here).
            weight: kernel tensor; the sizes of dims 3 and 2 are forwarded to
                the CUDA kernel as the kernel extents.
            stride, padding, dilation: int or pair, normalised with ``_pair``.
            groups: forwarded to the CUDA kernel.
            deformable_groups: forwarded to the CUDA kernel.
            im2col_step: upper bound for the batch chunk size; the value used
                is ``min(im2col_step, batch)`` and must divide the batch size.

        Returns:
            A new tensor whose shape comes from ``_output_size``.

        Raises:
            ValueError: if ``input`` is not 4D or an output dimension would
                be non-positive.
            NotImplementedError: if ``input`` is not a CUDA tensor.
            AssertionError: if the effective im2col step does not divide the
                batch size.
        """
        if input is not None and input.dim() != 4:
            raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
                input.dim()))
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))

        # Scratch buffers filled/resized by the extension and reused in backward.
        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
            # Pair-valued settings are passed with index 1 first, then index 0.
            deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
                                                      ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                      weight.size(2), ctx.stride[1], ctx.stride[0],
                                                      ctx.padding[1], ctx.padding[0],
                                                      ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                      ctx.deformable_groups, cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for ``input``, ``offset`` and ``weight``.

        Args:
            ctx: context populated by ``forward``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            A 9-tuple, one entry per ``forward`` argument:
            ``(grad_input, grad_offset, grad_weight, None, ..., None)``.
            A tensor gradient is ``None`` when it was not requested.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
            AssertionError: if the effective im2col step does not divide the
                batch size.
        """
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

            # A single kernel call produces both gradients, so both are
            # computed as soon as either one is needed.
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_cuda.deform_conv_backward_input_cuda(
                    input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
                    weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                    ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                    ctx.deformable_groups, cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_cuda.deform_conv_backward_parameters_cuda(
                    input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
                    weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                    ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                    ctx.deformable_groups, 1, cur_im2col_step)

        # forward() takes nine arguments after ctx, so nine slots are returned.
        # Returning only eight broke calls that passed im2col_step explicitly.
        # NOTE(review): surplus trailing None entries are expected to be
        # dropped by autograd when im2col_step is left at its default --
        # confirm on the targeted PyTorch version.
        return (grad_input, grad_offset, grad_weight, None, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        """Return the output shape ``(batch, out_channels, *spatial)``.

        Each spatial extent is
        ``(in + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1``.

        Raises:
            ValueError: if any resulting dimension is not strictly positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(s > 0 for s in output_size):
            raise ValueError("convolution input is too small (output would be {})".format('x'.join(
                map(str, output_size))))
        return output_size


class ModulatedDeformConvFunction(Function):
    """Autograd ``Function`` for modulated deformable convolution via ``deform_conv_cuda``.

    ``stride``, ``padding`` and ``dilation`` are used as scalars here: the
    shape arithmetic multiplies and divides them directly, and each one is
    handed to the CUDA kernel twice.
    """

    @staticmethod
    def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
                groups=1, deformable_groups=1):
        """Run the forward kernel and return the freshly allocated output.

        Raises:
            NotImplementedError: if ``input`` is not a CUDA tensor.
        """
        ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation
        ctx.groups, ctx.deformable_groups = groups, deformable_groups
        ctx.with_bias = bias is not None
        if bias is None:
            # The kernel always receives a bias tensor; this is a placeholder.
            bias = input.new_empty(1)
        if not input.is_cuda:
            raise NotImplementedError

        # Tensors are only saved when at least one of them needs a gradient.
        if any(t.requires_grad for t in (weight, mask, offset, input)):
            ctx.save_for_backward(input, offset, mask, weight, bias)

        out_shape = ModulatedDeformConvFunction._infer_shape(ctx, input, weight)
        output = input.new_empty(out_shape)
        # Scratch buffers shared between the forward and backward kernels.
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        kernel_h, kernel_w = weight.shape[2], weight.shape[3]
        deform_conv_cuda.modulated_deform_conv_cuda_forward(
            input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1],
            kernel_h, kernel_w,
            ctx.stride, ctx.stride,
            ctx.padding, ctx.padding,
            ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return one gradient slot per ``forward`` argument (10 in total).

        The five tensor gradients are zero-initialised and filled by the CUDA
        kernel; the bias gradient is reported as ``None`` when no bias was
        supplied to ``forward``.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError

        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = (
            torch.zeros_like(t) for t in (input, offset, mask, weight, bias))

        kernel_h, kernel_w = weight.shape[2], weight.shape[3]
        deform_conv_cuda.modulated_deform_conv_cuda_backward(
            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
            grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
            kernel_h, kernel_w,
            ctx.stride, ctx.stride,
            ctx.padding, ctx.padding,
            ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)

        return (grad_input, grad_offset, grad_mask, grad_weight,
                grad_bias if ctx.with_bias else None,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        """Return ``(batch, out_channels, out_height, out_width)`` for the output."""
        batch, out_channels = input.size(0), weight.size(0)
        in_h, in_w = input.shape[2:4]
        k_h, k_w = weight.shape[2:4]

        def _extent(size, kernel):
            # Standard convolution arithmetic with a dilated kernel footprint.
            footprint = ctx.dilation * (kernel - 1) + 1
            return (size + 2 * ctx.padding - footprint) // ctx.stride + 1

        return batch, out_channels, _extent(in_h, k_h), _extent(in_w, k_w)


# Functional entry points: expose each Function's ``apply`` under a plain
# function name; the module classes below call these.
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply


class DeformConv(nn.Module):
    """Deformable convolution layer whose sampling offsets are supplied by the caller.

    The layer has a single learnable parameter, ``weight``; a bias term is
    not supported (``bias`` must be falsy).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=False):
        """Validate the configuration and create the weight parameter.

        ``kernel_size``, ``stride``, ``padding`` and ``dilation`` are
        normalised with ``_pair``. ``in_channels`` and ``out_channels`` must
        both be multiples of ``groups``.
        """
        super().__init__()

        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} cannot be divisible by groups {}'.format(in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} cannot be divisible by groups {}'.format(out_channels, groups)

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        # Uninitialised storage; reset_parameters() fills it right below.
        raw_weight = torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)
        self.weight = nn.Parameter(raw_weight)

        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``weight`` uniformly in ``[-b, b)`` with ``b = 1 / sqrt(in_channels * prod(kernel_size))``."""
        fan = self.in_channels
        for extent in self.kernel_size:
            fan *= extent
        bound = 1. / math.sqrt(fan)
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x, offset):
        """Apply the deformable convolution to ``x`` using the given ``offset`` tensor."""
        return deform_conv(
            x, offset, self.weight,
            self.stride, self.padding, self.dilation,
            self.groups, self.deformable_groups)


class DeformConvPack(DeformConv):
    """``DeformConv`` that predicts its own sampling offsets from the input.

    A regular ``nn.Conv2d`` (``conv_offset``) produces
    ``deformable_groups * 2 * kH * kW`` offset channels. It is zero-initialised,
    so the predicted offsets are all zero until it has been trained.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``DeformConv`` and add the offset branch."""
        super(DeformConvPack, self).__init__(*args, **kwargs)

        # The offset branch must share kernel size, stride, padding AND
        # dilation with the deformable convolution; otherwise, for
        # dilation > 1, its output would generally have a different spatial
        # size than the deformable convolution output.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            dilation=_pair(self.dilation), bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the weight and bias of the offset-predicting convolution."""
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution to ``x``."""
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
                           self.groups, self.deformable_groups)


class ModulatedDeformConv(nn.Module):
    """Modulated deformable convolution layer; offsets and mask are supplied by the caller.

    Unlike ``kernel_size``, the ``stride``, ``padding`` and ``dilation``
    values are stored exactly as given (they are not expanded with ``_pair``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=True):
        """Store the configuration and create ``weight`` (and ``bias`` when requested)."""
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride, self.padding, self.dilation = stride, padding, dilation
        self.groups, self.deformable_groups = groups, deformable_groups
        self.with_bias = bias

        # Uninitialised storage; reset_parameters() fills it below.
        weight_shape = (out_channels, in_channels // groups) + tuple(self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        if not bias:
            # Keep a ``bias`` attribute (None) so the rest of the class can test it.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``weight`` uniformly in ``[-b, b)`` with ``b = 1 / sqrt(in_channels * prod(kernel_size))``; zero the bias."""
        fan = self.in_channels
        for extent in self.kernel_size:
            fan *= extent
        bound = 1. / math.sqrt(fan)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is None:
            return
        self.bias.data.zero_()

    def forward(self, x, offset, mask):
        """Apply the modulated deformable convolution to ``x`` with the given ``offset`` and ``mask``."""
        return modulated_deform_conv(
            x, offset, mask, self.weight, self.bias,
            self.stride, self.padding, self.dilation,
            self.groups, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
    """``ModulatedDeformConv`` that predicts its own offsets and modulation mask.

    A regular ``nn.Conv2d`` (``conv_offset_mask``) produces
    ``deformable_groups * 3 * kH * kW`` channels. The first two thirds are
    used as offsets and the last third, after a sigmoid, as the mask. The
    branch is zero-initialised, so until it has been trained the offsets are
    zero and the mask is 0.5 everywhere.
    """

    def __init__(self, *args, extra_offset_mask=False, **kwargs):
        """Build the layer.

        Args:
            *args, **kwargs: forwarded to ``ModulatedDeformConv``.
            extra_offset_mask: when true, ``forward`` expects a pair
                ``[input, features]`` and predicts offsets/mask from
                ``features`` instead of from the input itself.
        """
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.extra_offset_mask = extra_offset_mask
        # The prediction branch must share kernel size, stride, padding AND
        # dilation with the deformable convolution; otherwise, for
        # dilation > 1, its output would generally have a different spatial
        # size than the deformable convolution output.
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            dilation=_pair(self.dilation), bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the weight and bias of the offset/mask-predicting convolution."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        """Predict offsets and mask, then apply the modulated deformable convolution.

        Args:
            x: the input tensor, or ``[input, features]`` when
                ``extra_offset_mask`` is set.

        Returns:
            The result of ``modulated_deform_conv`` on the input.
        """
        if self.extra_offset_mask:
            # x = [input, features]
            out = self.conv_offset_mask(x[1])
            x = x[0]
        else:
            out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        # Diagnostic only: warn when the mean absolute offset exceeds 100.
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 100:
            logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))

        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
                                     self.padding, self.dilation, self.groups,
                                     self.deformable_groups)

Python Online Compiler

Write, Run & Share Python code online using OneCompiler's Python online compiler for free. It's one of the most robust, feature-rich online compilers for the Python language, supporting both Python 3 and Python 2.7. Getting started with OneCompiler's Python editor is easy and fast. The editor shows sample boilerplate code when you choose Python or Python2 as the language and start coding.

Taking inputs (stdin)

OneCompiler's Python online editor supports stdin, and users can give inputs to programs using the STDIN textbox under the I/O tab. The following is a sample Python program which takes a name as input and prints a greeting containing that name.

import sys
name = sys.stdin.readline()
print("Hello "+ name)

About Python

Python is a very popular general-purpose programming language which was created by Guido van Rossum and released in 1991. It is very popular for web development, and you can build almost anything with it: mobile apps, web apps, tools, data analytics, machine learning, etc. It is designed to be simple and easy to read, like the English language. It is highly productive and efficient, making it a very popular language.

Tutorial & Syntax help

Loops

1. If-Else:

Whenever you want to perform a set of operations based on a condition, IF-ELSE is used.

if conditional-expression:
    #code
elif conditional-expression:
    #code
else:
    #code

Note:

Indentation is very important in Python; make sure the indentation is followed correctly.

2. For:

A for loop is used to iterate over collections (list, tuple, set, dictionary) or strings.

Example:

mylist=("Iphone","Pixel","Samsung")
for i in mylist:
    print(i)

3. While:

While is also used to iterate over a set of statements based on a condition. Usually, while is preferred when the number of iterations is not known in advance.

while condition:
    #code

Collections

There are four types of collections in Python.

1. List:

List is a collection which is ordered and can be changed. Lists are specified in square brackets.

Example:

mylist=["iPhone","Pixel","Samsung"]
print(mylist)

2. Tuple:

A tuple is a collection which is ordered and cannot be changed. Tuples are specified in round brackets.

Example:

myTuple=("iPhone","Pixel","Samsung")
print(myTuple)

The code below throws an error, because an element of a tuple cannot be assigned a new value.

myTuple=("iPhone","Pixel","Samsung")
print(myTuple)
myTuple[1]="onePlus"
print(myTuple)

3. Set:

Set is a collection which is unordered and unindexed. Sets are specified in curly brackets.

Example:

myset = {"iPhone","Pixel","Samsung"}
print(myset)

4. Dictionary:

Dictionary is a collection of key value pairs which is unordered, can be changed, and indexed. They are written in curly brackets with key - value pairs.

Example:

mydict = {
    "brand" :"iPhone",
    "model": "iPhone 11"
}
print(mydict)

Supported Libraries

Following are the libraries supported by OneCompiler's Python compiler

Name | Description
NumPy | The NumPy Python library helps users work with arrays with ease
SciPy | SciPy is a scientific computation library which depends on NumPy for convenient and fast N-dimensional array manipulation
SKLearn/Scikit-learn | Scikit-learn is the most useful library for machine learning in Python
Pandas | Pandas is the most efficient Python library for data manipulation and analysis
DOcplex | DOcplex, IBM Decision Optimization CPLEX Modeling for Python, is a library composed of Mathematical Programming Modeling and Constraint Programming Modeling