export class Conv2D extends Module {
  public W: Tensor;
  public b: Tensor;
  public stride: number[];
  public padding: number[];
  public dilation: number[];
  public groups: number;
  public has_bias: boolean;

  /**
   * 2D convolutional layer supporting stride, padding, dilation, and grouped convolution.
   *
   * @param in_channels - Number of input channels; must be divisible by `groups`.
   * @param out_channels - Number of output channels (filters); must be divisible by `groups`.
   * @param kernel_size - Size of the convolving kernel, as `k` or `[kh, kw]`.
   * @param stride - Stride of the convolution, as `s` or `[sh, sw]`. Defaults to 1.
   * @param padding - Zero-padding added symmetrically on both sides of each spatial axis.
   *   Accepts `"same"` (derived from kernel, dilation and stride), `"valid"` (no padding),
   *   a single number, or `[ph, pw]`. Defaults to `"same"`.
   * @param dilation - Spacing between kernel elements, as `d` or `[dh, dw]`. Defaults to 1.
   * @param groups - Number of channel groups for grouped convolution. Defaults to 1.
   * @param bias - Whether to add a learnable per-output-channel bias. Defaults to true.
   * @param device - Device to perform Tensor operations, either "gpu" or "cpu". Defaults to "cpu".
   * @throws Error if the padding mode is unknown, a size/stride/dilation is not positive,
   *   padding is negative, stride or dilation exceed the kernel size, or a channel count
   *   is not divisible by `groups`.
   */
  constructor(
    in_channels: number,
    out_channels: number,
    kernel_size: number | [number, number],
    stride: number | [number, number] = 1,
    padding: number | string | [number, number] = "same",
    dilation: number | [number, number] = 1,
    groups: number = 1,
    bias: boolean = true,
    device: string = "cpu"
  ) {
    super();

    // Normalize scalar hyper-parameters to [height, width] pairs.
    const [kh, kw] = Array.isArray(kernel_size) ? kernel_size : [kernel_size, kernel_size];
    const [sh, sw] = Array.isArray(stride) ? stride : [stride, stride];
    const [dh, dw] = Array.isArray(dilation) ? dilation : [dilation, dilation];

    let ph: number;
    let pw: number;
    if (padding === "same") {
      // Symmetric padding: with stride 1 and an odd effective kernel this keeps the
      // spatial size unchanged; other combinations are only approximately "same".
      ph = Math.floor(((kh - 1) * dh + 1 - sh) / 2);
      pw = Math.floor(((kw - 1) * dw + 1 - sw) / 2);
    } else if (padding === "valid") {
      ph = pw = 0;
    } else if (Array.isArray(padding)) {
      [ph, pw] = padding;
    } else if (typeof padding === "number") {
      ph = pw = padding;
    } else {
      // Previously an unknown string was stored as the padding value and only failed
      // later (NaN output size) inside forward(); reject it up front instead.
      throw new Error(
        `Unknown padding mode "${padding}"; expected "same", "valid", a number, or [number, number].`
      );
    }

    // Validation checks
    if (Math.min(kh, kw, sh, sw, dh, dw) < 1) {
      throw new Error("kernel_size, stride, and dilation must be positive.");
    }
    if (sh > kh || sw > kw) {
      throw new Error("Stride cannot be larger than the kernel size.");
    }
    if (dh > kh || dw > kw) {
      throw new Error("Dilation cannot be larger than the kernel size.");
    }
    if (ph < 0 || pw < 0) {
      throw new Error("Padding values cannot be negative.");
    }
    if (!Number.isInteger(groups) || groups < 1) {
      throw new Error("groups must be a positive integer.");
    }
    if (in_channels % groups !== 0) {
      throw new Error("in_channels must be divisible by groups.");
    }
    if (out_channels % groups !== 0) {
      // Each group owns out_channels / groups filters; an uneven split is undefined.
      throw new Error("out_channels must be divisible by groups.");
    }

    // Initialize convolution properties
    this.stride = [sh, sw];
    this.padding = [ph, pw];
    this.dilation = [dh, dw];
    this.groups = groups;
    this.has_bias = bias;

    // Each filter only spans the input channels of its own group.
    const weight_shape = [out_channels, in_channels / groups, kh, kw];
    this.W = randn(weight_shape, true, device, false);
    if (bias) {
      // Keep the bias on the same device as the weights.
      this.b = zeros([out_channels], true, device);
    }
  }

  /**
   * Performs the forward pass through the Conv2D layer.
   *
   * The input is zero-padded once by `this.padding`. The output spatial size is then
   * `floor((padded - dilation * (kernel - 1) - 1) / stride) + 1` per axis. Each output
   * channel only reads the input channels that belong to its group.
   *
   * NOTE(review): the result is assembled from raw numbers (`.data`) into a new Tensor,
   * so it is not linked to `x`, `W`, or `b` through any Tensor operation. Confirm
   * whether gradients are expected to flow through this layer.
   *
   * @param x - Input tensor of shape [batch, in_channels, height, width].
   * @returns Output tensor of shape [batch, out_channels, out_height, out_width].
   * @throws Error if the input channel count does not match the weights and `groups`.
   */
  forward(x: Tensor): Tensor {
    const [padH, padW] = this.padding;
    const [strideH, strideW] = this.stride;
    const [dilH, dilW] = this.dilation;

    const input =
      padH > 0 || padW > 0
        ? x.pad([[0, 0], [0, 0], [padH, padH], [padW, padW]])
        : x;

    // These sizes are measured AFTER padding; padding must not be added again below.
    const [batch, in_channels, height, width] = input.shape;
    const [out_channels, channelsPerGroup, kernel_height, kernel_width] = this.W.shape;
    if (in_channels !== channelsPerGroup * this.groups) {
      throw new Error(
        `Conv2D expected ${channelsPerGroup * this.groups} input channels, got ${in_channels}.`
      );
    }
    if (out_channels % this.groups !== 0) {
      throw new Error("out_channels must be divisible by groups.");
    }
    const filtersPerGroup = out_channels / this.groups;

    // Extent of the (possibly dilated) receptive field along each axis.
    const spanH = dilH * (kernel_height - 1) + 1;
    const spanW = dilW * (kernel_width - 1) + 1;
    // Clamp at 0 so an input smaller than the receptive field yields an empty output.
    const out_height = Math.max(0, Math.floor((height - spanH) / strideH) + 1);
    const out_width = Math.max(0, Math.floor((width - spanW) / strideW) + 1);

    // Sum of the element-wise product of two equally-shaped 4D slices, as a plain number.
    const dot = (region: Tensor, filter: Tensor): number =>
      region.mul(filter).sum().sum().sum().data[0];

    const dense = dilH === 1 && dilW === 1;

    const outputData: number[][][][] = Array.from({ length: batch }, () =>
      Array.from({ length: out_channels }, () =>
        Array.from({ length: out_height }, () => new Array<number>(out_width).fill(0))
      )
    );

    for (let b = 0; b < batch; b++) {
      for (let oc = 0; oc < out_channels; oc++) {
        // First input channel of the group this filter belongs to.
        const channelStart = Math.floor(oc / filtersPerGroup) * channelsPerGroup;
        const biasValue = this.has_bias ? this.b.data[oc] : 0;

        // Filter slices depend only on oc, so take them once per output channel.
        // Dense kernels use one contiguous slice; dilated kernels are applied tap by tap.
        const filter = dense
          ? this.W.slice([oc, 0, 0, 0], [1, channelsPerGroup, kernel_height, kernel_width])
          : null;
        const taps: Array<{ dy: number; dx: number; w: Tensor }> = [];
        if (!dense) {
          for (let ki = 0; ki < kernel_height; ki++) {
            for (let kj = 0; kj < kernel_width; kj++) {
              taps.push({
                dy: ki * dilH,
                dx: kj * dilW,
                w: this.W.slice([oc, 0, ki, kj], [1, channelsPerGroup, 1, 1]),
              });
            }
          }
        }

        for (let i = 0; i < out_height; i++) {
          const hStart = i * strideH;
          for (let j = 0; j < out_width; j++) {
            const wStart = j * strideW;
            let acc = 0;
            if (filter !== null) {
              const region = input.slice(
                [b, channelStart, hStart, wStart],
                [1, channelsPerGroup, kernel_height, kernel_width]
              );
              acc = dot(region, filter);
            } else {
              for (const { dy, dx, w } of taps) {
                const tap = input.slice(
                  [b, channelStart, hStart + dy, wStart + dx],
                  [1, channelsPerGroup, 1, 1]
                );
                acc += dot(tap, w);
              }
            }
            outputData[b][oc][i][j] = acc + biasValue;
          }
        }
      }
    }

    return new Tensor(outputData);
  }


