This post reviews HW3, whose main task is to implement a simple linear algebra library.

Course homepage:

References:

Key points

Strides describe general memory access:

A[i, j] => Adata[i * A.strides[0] + j * A.strides[1]] 

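Changing the strides is enough to implement reshape without copying. For example, going from shape (1, 2) to (1, 1, 2) only requires inserting a new axis with stride 0.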

Part 1: Python array operations

Implement reshape, permute, broadcast_to and __getitem__, all mainly by manipulating strides.

stride:

A high-dimensional array is still stored as a one-dimensional buffer in memory. In row-major order, element access is:

A[i, j] => Adata[i * A.shape[1] + j] 

In column-major order it is:

A[i, j] => Adata[j * A.shape[0] + i] 
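To make the two layouts concrete, here is a small sketch that applies each formula to every element of a 2 x 3 array, visiting A[0, 0], A[0, 1], ..., A[1, 2] in order. The helper names are illustrative, not part of the assignment:

```python
def row_major_index(i, j, shape):
    # A[i, j] => Adata[i * A.shape[1] + j]
    return i * shape[1] + j


def col_major_index(i, j, shape):
    # A[i, j] => Adata[j * A.shape[0] + i]
    return j * shape[0] + i


shape = (2, 3)
# buffer position of each element, visiting A[0, 0], A[0, 1], ..., A[1, 2]
row_order = [row_major_index(i, j, shape) for i in range(shape[0]) for j in range(shape[1])]
col_order = [col_major_index(i, j, shape) for i in range(shape[0]) for j in range(shape[1])]
print(row_order)  # [0, 1, 2, 3, 4, 5]
print(col_order)  # [0, 2, 4, 1, 3, 5]
```

In row-major order, neighbours along the last axis are adjacent in memory; in column-major order, neighbours along the first axis are.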

More generally, we introduce strides:

A[i, j] => Adata[i * A.strides[0] + j * A.strides[1]] 

Strides make reshape easy. For example, going from shape (1, 2) to (1, 1, 2) just inserts an axis with stride 0:

A[i, j, k] => Adata[i * A.strides[0] + j * 0 + k * A.strides[2]] 

The most general form also adds an offset, i.e., the position in the buffer where the array starts:

A[i, j] => Adata[offset + i * A.strides[0] + j * A.strides[1]] 
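A minimal pure-Python sketch of this addressing rule, assuming a flat Python list as the buffer. `compact_strides` computes the row-major strides of a compact array (the role `NDArray.compact_strides` plays in the code below), and `flat_index` is a hypothetical helper for the formula above. The example checks that inserting a size-1 axis with stride 0 addresses the same elements, and that an offset alone can select one row:

```python
def compact_strides(shape):
    """Row-major strides of a compact array: the last dimension has stride 1."""
    strides, stride = [], 1
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    return tuple(reversed(strides))


def flat_index(idx, strides, offset=0):
    """A[idx] => Adata[offset + sum(idx[d] * strides[d])]."""
    return offset + sum(i * s for i, s in zip(idx, strides))


data = list(range(6))            # underlying 1D buffer
s2 = compact_strides((2, 3))     # (3, 1): plain row-major
s3 = (s2[0], 0, s2[1])           # view as shape (2, 1, 3): the new size-1 axis has stride 0
same = all(
    data[flat_index((i, j), s2)] == data[flat_index((i, 0, j), s3)]
    for i in range(2) for j in range(3)
)
# row 1 as a standalone 1D view: shape (3,), strides (1,), offset 3
row1 = [data[flat_index((j,), (1,), offset=3)] for j in range(3)]
print(s2, same, row1)  # (3, 1) True [3, 4, 5]
```

Because the stride-0 axis only ever has index 0, it contributes nothing to the address, which is why reshape to (1, 1, 2) needs no copy.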

The other thing to watch out for is computing the output shape. The full code is as follows:

def reshape(self, new_shape):
    """
    Reshape the matrix without copying memory.  This will return a matrix
    that corresponds to a reshaped array but points to the same memory as
    the original array.
    Raises:
        ValueError if product of current shape is not equal to the product
        of the new shape, or if the matrix is not compact.
    Args:
        new_shape (tuple): new shape of the array
    Returns:
        NDArray : reshaped array; this will point to the same memory as the original NDArray.
    """

    ### BEGIN YOUR SOLUTION
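    # the docstring requires both checks: reshape only relabels a compact buffer
    if not self.is_compact():
        raise ValueError("reshape requires a compact array")
    if self.size != prod(new_shape):
        raise ValueError("reshape must preserve the number of elements")
    new_strides = self.compact_strides(new_shape)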
    
    return NDArray.make(
        shape=new_shape, 
        strides=new_strides, 
        device=self._device, 
        handle=self._handle,
        offset=self._offset,
    )
    ### END YOUR SOLUTION

def permute(self, new_axes):
    """
    Permute order of the dimensions.  new_axes describes a permutation of the
    existing axes, so e.g.:
        - If we have an array with dimension "BHWC" then .permute((0,3,1,2))
        would convert this to "BCHW" order.
        - For a 2D array, .permute((1,0)) would transpose the array.
    Like reshape, this operation should not copy memory, but achieves the
    permuting by just adjusting the shape/strides of the array.  That is,
    it returns a new array that has the dimensions permuted as desired, but
    which points to the same memory as the original array.
    Args:
        new_axes (tuple): permutation order of the dimensions
    Returns:
        NDarray : new NDArray object with permuted dimensions, pointing
        to the same memory as the original NDArray (i.e., just shape and
        strides changed).
    """
    ### BEGIN YOUR SOLUTION
    new_shape = [0 for _ in range(self.ndim)]
    new_strides = [0 for _ in range(self.ndim)]
    for i, j in enumerate(new_axes):
        new_shape[i] = self._shape[j]
        new_strides[i] = self._strides[j]
    
    return NDArray.make(
        shape=new_shape, 
        strides=tuple(new_strides), 
        device=self._device, 
        handle=self._handle,
        offset=self._offset,
    )
    ### END YOUR SOLUTION

def broadcast_to(self, new_shape):
    """
    Broadcast an array to a new shape.  new_shape's elements must be the
    same as the original shape, except for dimensions in the self where
    the size = 1 (which can then be broadcast to any size).  As with the
    previous calls, this will not copy memory, and just achieves
    broadcasting by manipulating the strides.
    Raises:
        assertion error if new_shape[i] != shape[i] for all i where
        shape[i] != 1
    Args:
        new_shape (tuple): shape to broadcast to
    Returns:
        NDArray: the new NDArray object with the new broadcast shape; should
        point to the same memory as the original array.
    """

    ### BEGIN YOUR SOLUTION
    if len(new_shape) != self.ndim:
        raise AssertionError
    new_strides = [0 for _ in range(self.ndim)]
    for i in range(self.ndim):
        if new_shape[i] == self._shape[i]:
            new_strides[i] = self._strides[i]
        elif self._shape[i] == 1:
            new_strides[i] = 0
        else:
            raise AssertionError
    
    return NDArray.make(
        shape=new_shape, 
        strides=tuple(new_strides), 
        device=self._device, 
        handle=self._handle,
        offset=self._offset,
    )
    ### END YOUR SOLUTION
    
def __getitem__(self, idxs):
    """
    The __getitem__ operator in Python allows us to access elements of our
    array.  When passed notation such as a[1:5,:-1:2,4,:] etc, Python will
    convert this to a tuple of slices and integers (for singletons like the
    '4' in this example).  Slices can be a bit odd to work with (they have
    three elements .start .stop .step), which can be None or have negative
    entries, so for simplicity we wrote the code for you to convert these
    to always be a tuple of slices, one of each dimension.
    For this tuple of slices, return an array that subsets the desired
    elements.  As before, this can be done entirely by computing a new
    shape, stride, and offset for the new "view" into the original array,
    pointing to the same memory
    Raises:
        AssertionError if a slice has negative size or step, or if number
        of slices is not equal to the number of dimensions (the stub code
        already raises all these errors).
    Args:
        idxs tuple: (after stub code processes), a tuple of slice elements
        corresponding to the subset of the matrix to get
    Returns:
        NDArray: a new NDArray object corresponding to the selected
        subset of elements.  As before, this should not copy memory but just
        manipulate the shape/strides/offset of the new array, referencing
        the same array as the original one.
    """

    # handle singleton as tuple, everything as slices
    if not isinstance(idxs, tuple):
        idxs = (idxs,)
    idxs = tuple(
        [
            self.process_slice(s, i) if isinstance(s, slice) else slice(s, s + 1, 1)
            for i, s in enumerate(idxs)
        ]
    )
    assert len(idxs) == self.ndim, "Need indexes equal to number of dimensions"

    ### BEGIN YOUR SOLUTION
    new_shape = [0 for _ in range(self.ndim)]
    new_strides = [0 for _ in range(self.ndim)]
    new_offset = self._offset
    for i in range(self.ndim):
        start, stop, step = idxs[i].start, idxs[i].stop, idxs[i].step
        # number of selected elements is ceil((stop - start) / step)
        new_shape[i] = (stop - start + step - 1) // step
        new_strides[i] = self._strides[i] * step
        # the view starts at the first selected element, on top of any existing offset
        new_offset += self._strides[i] * start
    
    return NDArray.make(
        shape=new_shape, 
        strides=tuple(new_strides), 
        device=self._device, 
        handle=self._handle,
        offset=new_offset,
    )
    ### END YOUR SOLUTION
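To check the shape/stride/offset arithmetic without building the backend, here is a standalone pure-Python model. `View` and `materialize` are hypothetical helpers, not part of needle; slices are assumed to be already normalized the way `process_slice` leaves them (start < stop, step > 0):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class View:
    """Hypothetical stand-in for NDArray metadata: a strided view over a flat buffer."""
    shape: tuple
    strides: tuple
    offset: int = 0

    def permute(self, axes):
        # new axis i takes the size and stride of old axis axes[i]
        return View(tuple(self.shape[a] for a in axes),
                    tuple(self.strides[a] for a in axes), self.offset)

    def broadcast_to(self, new_shape):
        assert len(new_shape) == len(self.shape)
        strides = []
        for old, new, st in zip(self.shape, new_shape, self.strides):
            assert old == new or old == 1
            # a size-1 axis is repeated by giving it stride 0
            strides.append(st if old == new else 0)
        return View(tuple(new_shape), tuple(strides), self.offset)

    def getitem(self, slices):
        shape = tuple((s.stop - s.start + s.step - 1) // s.step for s in slices)
        strides = tuple(st * s.step for st, s in zip(self.strides, slices))
        offset = self.offset + sum(st * s.start for st, s in zip(self.strides, slices))
        return View(shape, strides, offset)


def materialize(view, data):
    """Read the view's elements in row-major order (what compact() would produce)."""
    out = []

    def walk(dim, pos):
        if dim == len(view.shape):
            out.append(data[pos])
            return
        for i in range(view.shape[dim]):
            walk(dim + 1, pos + i * view.strides[dim])

    walk(0, view.offset)
    return out


data = list(range(12))
a = View((3, 4), (4, 1))  # compact 3x4 array holding 0..11
print(materialize(a.permute((1, 0)), data))  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
print(materialize(View((1, 4), (4, 1)).broadcast_to((2, 4)), data))  # [0, 1, 2, 3, 0, 1, 2, 3]
print(materialize(a.getitem((slice(0, 3, 2), slice(1, 4, 2))), data))  # [1, 3, 9, 11]
```

The `getitem` case is where it is easy to slip: the shape needs ceiling division (rows 0 and 2 out of 3 give a size of 2), and the offset must accumulate onto the existing offset so that slicing a slice still lands in the right place.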

Part 2: CPU Backend - Compact and setitem

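The main difficulty is iterating over $n$ dimensions with a single loop. The core idea resembles big-integer addition: treat the index along dimension $i$ as the $i$-th digit of a big integer, add 1 to the last digit after each element, and carry into the previous digit whenever a digit reaches the size of its dimension: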

void Compact(const AlignedArray& a, AlignedArray* out, std::vector<uint32_t> shape,
             std::vector<uint32_t> strides, size_t offset) {
  /**
   * Compact an array in memory
   *
   * Args:
   *   a: non-compact representation of the array, given as input
   *   out: compact version of the array to be written
   *   shape: shapes of each dimension for a and out
   *   strides: strides of the *a* array (not out, which has compact strides)
   *   offset: offset of the *a* array (not out, which has zero offset, being compact)
   *
   * Returns:
   *  void (you need to modify out directly, rather than returning anything; this is true for all the
   *  function will implement here, so we won't repeat this note.)
   */
  /// BEGIN YOUR SOLUTION
  // out has already been allocated by the caller
  size_t size = out->size;
  int n = shape.size();
  // current multi-index; the last dimension is the lowest "digit"
  std::vector<uint32_t> indexs(n, 0);
  for (size_t j = 0; j < size; j++) {
    size_t k = offset;
    for (int l = 0; l < n; l++) {
      k += indexs[l] * strides[l];
    }
    out->ptr[j] = a.ptr[k];
    // add 1 to the lowest digit and propagate the carry toward dimension 0
    for (int i = n - 1; i >= 0; i--) {
      if (++indexs[i] < shape[i]) break;
      indexs[i] = 0;
    }
  }
  /// END YOUR SOLUTION
}

void EwiseSetitem(const AlignedArray& a, AlignedArray* out, std::vector<uint32_t> shape,
                  std::vector<uint32_t> strides, size_t offset) {
  /**
   * Set items in a (non-compact) array
   *
   * Args:
   *   a: _compact_ array whose items will be written to out
   *   out: non-compact array whose items are to be written
   *   shape: shapes of each dimension for a and out
   *   strides: strides of the *out* array (not a, which has compact strides)
   *   offset: offset of the *out* array (not a, which has zero offset, being compact)
   */
  /// BEGIN YOUR SOLUTION
  // out has already been allocated by the caller
  size_t size = a.size;
  int n = shape.size();
  // current multi-index; the last dimension is the lowest "digit"
  std::vector<uint32_t> indexs(n, 0);
  for (size_t j = 0; j < size; j++) {
    size_t k = offset;
    for (int l = 0; l < n; l++) {
      k += indexs[l] * strides[l];
    }
    out->ptr[k] = a.ptr[j];
    // add 1 to the lowest digit and propagate the carry toward dimension 0
    for (int i = n - 1; i >= 0; i--) {
      if (++indexs[i] < shape[i]) break;
      indexs[i] = 0;
    }
  }
  /// END YOUR SOLUTION
}

void ScalarSetitem(const size_t size, scalar_t val, AlignedArray* out, std::vector<uint32_t> shape,
                   std::vector<uint32_t> strides, size_t offset) {
  /**
   * Set items in a (non-compact) array
   *
   * Args:
   *   size: number of elements to write in out array (note that this will not be the same as
   *         out.size, because out is a non-compact subset array);  it _will_ be the same as the
   *         product of items in shape, but convenient to just pass it here.
   *   val: scalar value to write to
   *   out: non-compact array whose items are to be written
   *   shape: shapes of each dimension of out
   *   strides: strides of the out array
   *   offset: offset of the out array
   */

  /// BEGIN YOUR SOLUTION
  int n = shape.size();
  // current multi-index; the last dimension is the lowest "digit"
  std::vector<uint32_t> indexs(n, 0);
  for (size_t j = 0; j < size; j++) {
    size_t k = offset;
    for (int l = 0; l < n; l++) {
      k += indexs[l] * strides[l];
    }
    out->ptr[k] = val;
    // add 1 to the lowest digit and propagate the carry toward dimension 0
    for (int i = n - 1; i >= 0; i--) {
      if (++indexs[i] < shape[i]) break;
      indexs[i] = 0;
    }
  }
  /// END YOUR SOLUTION
}
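The carry loop shared by the three functions above can be modeled in Python to check it on small inputs. This is a sketch with hypothetical helper names, using a plain list as the buffer; it mirrors the C++ logic, not the actual backend:

```python
from math import prod


def odometer(shape):
    """Yield every multi-index of `shape` in row-major order. The last dimension is the
    lowest "digit"; when a digit reaches its dimension's size it resets and carries."""
    idx = [0] * len(shape)
    for _ in range(prod(shape)):
        yield tuple(idx)
        for d in range(len(shape) - 1, -1, -1):
            idx[d] += 1
            if idx[d] < shape[d]:
                break
            idx[d] = 0  # carry into dimension d - 1


def position(idx, strides, offset):
    return offset + sum(i * s for i, s in zip(idx, strides))


def compact(a, shape, strides, offset):
    # Compact: out[j] = a[position of the j-th multi-index]
    return [a[position(idx, strides, offset)] for idx in odometer(shape)]


def ewise_setitem(a, out, shape, strides, offset):
    # EwiseSetitem: the same walk with the roles swapped, out[position] = a[j]
    for j, idx in enumerate(odometer(shape)):
        out[position(idx, strides, offset)] = a[j]


def scalar_setitem(val, out, shape, strides, offset):
    # ScalarSetitem: every position in the view receives val
    for idx in odometer(shape):
        out[position(idx, strides, offset)] = val


buf = list(range(12))                             # compact 3x4 array
print(list(odometer((2, 3))))
print(compact(buf, (3, 2), (4, 2), 1))            # columns 1 and 3: [1, 3, 5, 7, 9, 11]
ewise_setitem([-1, -2, -3], buf, (3,), (4,), 0)   # write into column 0
scalar_setitem(0, buf, (2,), (1,), 10)            # zero the last two elements
print(buf)                                        # [-1, 1, 2, 3, -2, 5, 6, 7, -3, 9, 0, 0]
```

Compact and EwiseSetitem differ only in which side of the assignment uses the strided position, which is why the three C++ functions share the same loop.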

Part 3: CPU Backend - Elementwise and scalar operations

This part is straightforward; the main trick is using function templates so that one loop serves every operation:

/// BEGIN YOUR SOLUTION
template <typename F>
void EwiseFun(const AlignedArray& a, const AlignedArray& b, AlignedArray* out, F f) {
  for (size_t i = 0; i < a.size; i++) {
    out->ptr[i] = f(a.ptr[i], b.ptr[i]);
  }
}

template <typename F>
void ScalarFun(const AlignedArray& a, scalar_t val, AlignedArray* out, F f) {
  for (size_t i = 0; i < a.size; i++) {
    out->ptr[i] = f(a.ptr[i], val);
  }
}

template <typename F>
void SingleEwiseFun(const AlignedArray& a, AlignedArray* out, F f) {
  for (size_t i = 0; i < a.size; i++) {
    out->ptr[i] = f(a.ptr[i]);
  }
}

// EwiseMul, ScalarMul
scalar_t Mul(scalar_t a, scalar_t b) {
  return a * b;
}

void EwiseMul(const AlignedArray& a, const AlignedArray& b, AlignedArray* out) {
  return EwiseFun(a, b, out, Mul);
}

void ScalarMul(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  return ScalarFun(a, val, out, Mul);
}

// EwiseDiv, ScalarDiv
scalar_t Div(scalar_t a, scalar_t b) {
  return a / b;
}

void EwiseDiv(const AlignedArray& a, const AlignedArray& b, AlignedArray* out) {
  return EwiseFun(a, b, out, Div);
}

void ScalarDiv(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  return ScalarFun(a, val, out, Div);
}

// ScalarPower
scalar_t Power(scalar_t a, scalar_t b) {
  return pow(a, b);
}

void ScalarPower(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  return ScalarFun(a, val, out, Power);
}

// EwiseMaximum, ScalarMaximum
scalar_t Maximum(scalar_t a, scalar_t b) {
  return std::max(a, b);
}

void EwiseMaximum(const AlignedArray& a, const AlignedArray& b, AlignedArray* out) {
  return EwiseFun(a, b, out, Maximum);
}

void ScalarMaximum(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  return ScalarFun(a, val, out, Maximum);
}

// EwiseEq, ScalarEq
scalar_t Eq(scalar_t a, scalar_t b) {
  return std::fabs(a - b) < 1e-6;
}

void EwiseEq(const AlignedArray& a, const AlignedArray& b, AlignedArray* out) {
  return EwiseFun(a, b, out, Eq);
}

void ScalarEq(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  return ScalarFun(a, val, out, Eq);
}

// EwiseGe, ScalarGe
scalar_t Ge(scalar_t a, scalar_t b) {
  return a >= b;
}

void EwiseGe(const AlignedArray& a, const AlignedArray& b, AlignedArray* out) {
  return EwiseFun(a, b, out, Ge);
}

void ScalarGe(const AlignedArray& a, scalar_t val, AlignedArray* out) {
  return ScalarFun(a, val, out, Ge);
}

// EwiseLog
scalar_t Log(scalar_t a) {
  return log(a);
}

void EwiseLog(const AlignedArray& a, AlignedArray* out) {
  return SingleEwiseFun(a, out, Log);
}

// EwiseExp
scalar_t Exp(scalar_t a) {
  return exp(a);
}

void EwiseExp(const AlignedArray& a, AlignedArray* out) {
  return SingleEwiseFun(a, out, Exp);
}

// EwiseTanh
scalar_t Tanh(scalar_t a) {
  return tanh(a);
}

void EwiseTanh(const AlignedArray& a, AlignedArray* out) {
  return SingleEwiseFun(a, out, Tanh);
}

/// END YOUR SOLUTION
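The same "one loop, many small functions" structure can be sketched in Python, where the operation is simply passed as a callable. The helper names mirror `EwiseFun`, `ScalarFun` and `SingleEwiseFun` above but are only an illustration; comparisons return 1.0 or 0.0, like the C++ versions that convert `bool` to `scalar_t`:

```python
import math


def ewise_fun(a, b, f):
    # one loop shared by every binary element-wise op (mirrors EwiseFun)
    return [f(x, y) for x, y in zip(a, b)]


def scalar_fun(a, val, f):
    # mirrors ScalarFun: the second operand is the same scalar for every element
    return [f(x, val) for x in a]


def single_ewise_fun(a, f):
    # mirrors SingleEwiseFun for unary ops such as log/exp/tanh
    return [f(x) for x in a]


def mul(x, y):
    return x * y


def eq(x, y):
    # tolerance-based equality with the same 1e-6 threshold as Eq above
    return float(abs(x - y) < 1e-6)


def ge(x, y):
    return float(x >= y)


a, b = [1.0, 2.0, 3.0], [3.0, 2.0, 1.0]
print(ewise_fun(a, b, mul))               # [3.0, 4.0, 3.0]
print(scalar_fun(a, 2.0, ge))             # [0.0, 1.0, 1.0]
print(ewise_fun(a, b, eq))                # [0.0, 1.0, 0.0]
print(single_ewise_fun([0.0], math.exp))  # [1.0]
```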

Part 4: CPU Backend - Reductions

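Also not difficult: output element i is reduced from the contiguous block that starts at index i * reduce_size of the input: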

void ReduceMax(const AlignedArray& a, AlignedArray* out, size_t reduce_size) {
  /**
   * Reduce by taking maximum over `reduce_size` contiguous blocks.
   *
   * Args:
   *   a: compact array of size a.size = out.size * reduce_size to reduce over
   *   out: compact array to write into
   *   reduce_size: size of the dimension to reduce over
   */

  /// BEGIN YOUR SOLUTION
  size_t n = out->size;
  for (int i = 0; i < n; i++) {
    int j = i * reduce_size;
    scalar_t res = a.ptr[j];
    for (int k = 0; k < reduce_size; k++) {
      res = std::max(res, a.ptr[j + k]);
    }
    out->ptr[i] = res;
  }
  /// END YOUR SOLUTION
}

void ReduceSum(const AlignedArray& a, AlignedArray* out, size_t reduce_size) {
  /**
   * Reduce by taking sum over `reduce_size` contiguous blocks.
   *
   * Args:
   *   a: compact array of size a.size = out.size * reduce_size to reduce over
   *   out: compact array to write into
   *   reduce_size: size of the dimension to reduce over
   */

  /// BEGIN YOUR SOLUTION
  size_t n = out->size;
  for (int i = 0; i < n; i++) {
    int j = i * reduce_size;
    scalar_t res = 0;
    for (int k = 0; k < reduce_size; k++) {
      res += a.ptr[j + k];
    }
    out->ptr[i] = res;
  }
  /// END YOUR SOLUTION
}
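A small Python sketch of the block layout these kernels assume (helper names are hypothetical). Reducing over a non-last axis only fits this layout after that axis has been moved to the end and the array compacted; the sketch does that by hand for a transpose, which is presumably what the Python-side NDArray arranges before calling the backend (an assumption here):

```python
def reduce_max(a, reduce_size):
    # block i covers a[i * reduce_size : (i + 1) * reduce_size], i.e. j = i * reduce_size above
    n = len(a) // reduce_size
    return [max(a[i * reduce_size:(i + 1) * reduce_size]) for i in range(n)]


def reduce_sum(a, reduce_size):
    n = len(a) // reduce_size
    return [sum(a[i * reduce_size:(i + 1) * reduce_size]) for i in range(n)]


x = [1, 5, 2,
     7, 0, 3]  # compact 2x3 array
print(reduce_max(x, 3), reduce_sum(x, 3))  # reduce over the last axis: [5, 7] [8, 10]
# reducing over axis 0: compact the 3x2 transpose first, then reduce blocks of 2
xt = [x[r * 3 + c] for c in range(3) for r in range(2)]  # [1, 7, 5, 0, 2, 3]
print(reduce_sum(xt, 2))  # [8, 5, 5]
```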

Part 5: CPU Backend - Matrix multiplication

Plain matrix multiplication and tiled (blocked) matrix multiplication:

void Matmul(const AlignedArray& a, const AlignedArray& b, AlignedArray* out, uint32_t m, uint32_t n,
            uint32_t p) {
  /**
   * Multiply two (compact) matrices into an output (also compact) matrix.  For this implementation
   * you can use the "naive" three-loop algorithm.
   *
   * Args:
   *   a: compact 2D array of size m x n
   *   b: compact 2D array of size n x p
   *   out: compact 2D array of size m x p to write the output to
   *   m: rows of a / out
   *   n: columns of a / rows of b
   *   p: columns of b / out
   */

  /// BEGIN YOUR SOLUTION
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < p; j++) {
      int r = i * p + j;
      out->ptr[r] = 0;
      for (int k = 0; k < n; k++) {
        int s = i * n + k;
        int t = k * p + j;
        out->ptr[r] += a.ptr[s] * b.ptr[t];
      }
    }
  }
  /// END YOUR SOLUTION
}

inline void AlignedDot(const float* __restrict__ a,
                       const float* __restrict__ b,
                       float* __restrict__ out) {

  /**
   * Multiply together two TILE x TILE matrices, and _add _the result to out (it is important to add
   * the result to the existing out, which you should not set to zero beforehand).  We are including
   * the compiler flags here that enable the compile to properly use vector operators to implement
   * this function.  Specifically, the __restrict__ keyword indicates to the compile that a, b, and
   * out don't have any overlapping memory (which is necessary in order for vector operations to be
   * equivalent to their non-vectorized counterparts (imagine what could happen otherwise if a, b,
   * and out had overlapping memory).  Similarly the __builtin_assume_aligned keyword tells the
   * compiler that the input array will be aligned to the appropriate blocks in memory, which also
   * helps the compiler vectorize the code.
   *
   * Args:
   *   a: compact 2D array of size TILE x TILE
   *   b: compact 2D array of size TILE x TILE
   *   out: compact 2D array of size TILE x TILE to write to
   */

  a = (const float*)__builtin_assume_aligned(a, TILE * ELEM_SIZE);
  b = (const float*)__builtin_assume_aligned(b, TILE * ELEM_SIZE);
  out = (float*)__builtin_assume_aligned(out, TILE * ELEM_SIZE);

  /// BEGIN YOUR SOLUTION
  for (int i = 0; i < TILE; i++) {
    for (int j = 0; j < TILE; j++) {
      int r = i * TILE + j;
      for (int k = 0; k < TILE; k++) {
        int s = i * TILE + k;
        int t = k * TILE + j;
        out[r] += a[s] * b[t];
      }
    }
  }
  /// END YOUR SOLUTION
}

void MatmulTiled(const AlignedArray& a, const AlignedArray& b, AlignedArray* out, uint32_t m,
                 uint32_t n, uint32_t p) {
  /**
   * Matrix multiplication on tiled representations of array.  In this setting, a, b, and out
   * are all *4D* compact arrays of the appropriate size, e.g. a is an array of size
   *   a[m/TILE][n/TILE][TILE][TILE]
   * You should do the multiplication tile-by-tile to improve performance of the array (i.e., this
   * function should call `AlignedDot()` implemented above).
   *
   * Note that this function will only be called when m, n, p are all multiples of TILE, so you can
   * assume that this division happens without any remainder.
   *
   * Args:
   *   a: compact 4D array of size m/TILE x n/TILE x TILE x TILE
   *   b: compact 4D array of size n/TILE x p/TILE x TILE x TILE
   *   out: compact 4D array of size m/TILE x p/TILE x TILE x TILE to write to
   *   m: rows of a / out
   *   n: columns of a / rows of b
   *   p: columns of b / out
   *
   */
  /// BEGIN YOUR SOLUTION
  size_t m1 = m / TILE;
  size_t n1 = n / TILE;
  size_t p1 = p / TILE;
  size_t TILE2 = TILE * TILE;
  for (size_t i = 0; i < m1; i++) {
    for (size_t j = 0; j < p1; j++) {
      // tile (i, j) of out is a contiguous block of TILE2 elements in the 4D compact layout
      float* out_tile = out->ptr + (i * p1 + j) * TILE2;
      // AlignedDot accumulates into out, so zero the tile first
      for (size_t u = 0; u < TILE2; u++) {
        out_tile[u] = 0;
      }
      for (size_t k = 0; k < n1; k++) {
        // tiles a[i][k] and b[k][j] are also contiguous and start a multiple of TILE2
        // elements from the (aligned) buffer start, so they can be passed in directly
        AlignedDot(a.ptr + (i * n1 + k) * TILE2, b.ptr + (k * p1 + j) * TILE2, out_tile);
      }
    }
  }
  /// END YOUR SOLUTION
}
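The tiled layout is easiest to verify against the naive product. The Python sketch below (hypothetical helpers, small sizes) rearranges compact matrices into the `[rows/TILE][cols/TILE][TILE][TILE]` order, multiplies tile by tile with an accumulating `aligned_dot`, and compares with the naive result converted to the same layout:

```python
def naive_matmul(a, b, m, n, p):
    """a: m x n, b: n x p, both compact row-major lists; returns compact m x p."""
    out = [0] * (m * p)
    for i in range(m):
        for j in range(p):
            out[i * p + j] = sum(a[i * n + k] * b[k * p + j] for k in range(n))
    return out


def to_tiled(x, rows, cols, T):
    """Rearrange a compact rows x cols matrix into compact [rows/T][cols/T][T][T] order."""
    return [x[(bi * T + ti) * cols + bj * T + tj]
            for bi in range(rows // T) for bj in range(cols // T)
            for ti in range(T) for tj in range(T)]


def aligned_dot(a, a_off, b, b_off, out, out_off, T):
    """out tile += a tile @ b tile (accumulates, like AlignedDot)."""
    for i in range(T):
        for j in range(T):
            for k in range(T):
                out[out_off + i * T + j] += a[a_off + i * T + k] * b[b_off + k * T + j]


def matmul_tiled(a, b, m, n, p, T):
    m1, n1, p1, T2 = m // T, n // T, p // T, T * T
    out = [0] * (m * p)  # every output tile starts at zero
    for i in range(m1):
        for j in range(p1):
            for k in range(n1):
                aligned_dot(a, (i * n1 + k) * T2, b, (k * p1 + j) * T2,
                            out, (i * p1 + j) * T2, T)
    return out


m, n, p, T = 4, 6, 2, 2
a = list(range(m * n))
b = list(range(n * p))
expected = to_tiled(naive_matmul(a, b, m, n, p), m, p, T)
got = matmul_tiled(to_tiled(a, m, n, T), to_tiled(b, n, p, T), m, n, p, T)
print(got == expected)  # True
```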

Part 6: CUDA Backend - Compact and setitem

Now implement the CUDA versions, following the CUDA usage patterns shown in the lecture slides:

__global__ void CompactKernel(const scalar_t* a, scalar_t* out, size_t size, CudaVec shape,
                              CudaVec strides, size_t offset) {
  /**
   * The CUDA kernel for the compact operation.  This should effectively map a single entry in the 
   * non-compact input a, to the corresponding item (at location gid) in the compact array out.
   * 
   * Args:
   *   a: CUDA pointer to a array
   *   out: CUDA pointer to out array
   *   size: size of out array
   *   shape: vector of shapes of a and out arrays (of type CudaVec, for fast passing to CUDA kernel)
   *   strides: vector of strides of a array
   *   offset: offset of a array
   */
  // index into out (one thread per output element)
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;

  /// BEGIN YOUR SOLUTION
  if (gid < size) {
    int n = shape.size;
    int i1 = offset;
    int j1 = gid;
    for (int i = n - 1; i >= 0; i--) {
      i1 += (j1 % shape.data[i]) * strides.data[i];
      j1 /= shape.data[i];
    }
    out[gid] = a[i1];
  }
  /// END YOUR SOLUTION
}

void Compact(const CudaArray& a, CudaArray* out, std::vector<uint32_t> shape,
             std::vector<uint32_t> strides, size_t offset) {
  /**
   * Compact an array in memory.  Unlike the C++ version, in CUDA this will primarily call the 
   * relevant CUDA kernel.  In this case, we illustrate how you should set this up (i.e., we give 
   * you the code for this function, and also the prototype for the CompactKernel() function).  For
   * the functions after this, however, you'll need to define these kernels as you see fit to 
   * execute the underlying function.
   * 
   * Args:
   *   a: non-compact representation of the array, given as input
   *   out: compact version of the array to be written
   *   shape: shapes of each dimension for a and out
   *   strides: strides of the *a* array (not out, which has compact strides)
   *   offset: offset of the *a* array (not out, which has zero offset, being compact)
   */

  // Nothing needs to be added here
  CudaDims dim = CudaOneDim(out->size);
  CompactKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, out->size, VecToCuda(shape),
                                         VecToCuda(strides), offset);
}


__global__ void EwiseSetitemKernel(const scalar_t* a, scalar_t* out, size_t size, CudaVec shape,
                              CudaVec strides, size_t offset) {
  /**
   * The CUDA kernel for the EwiseSetitem operation.  This should effectively map a single entry
   * (at location gid) in the compact input a to the corresponding item in the non-compact array out.
   * 
   * Args:
   *   a: CUDA pointer to a array
   *   out: CUDA pointer to out array
   *   size: size of a array (number of elements to write)
   *   shape: vector of shapes of a and out arrays (of type CudaVec, for fast passing to CUDA kernel)
   *   strides: vector of strides of out array
   *   offset: offset of out array
   */
  // index into a (one thread per element of a)
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;

  /// BEGIN YOUR SOLUTION
  if (gid < size) {
    int n = shape.size;
    int i1 = offset;
    int j1 = gid;
    for (int i = n - 1; i >= 0; i--) {
      i1 += (j1 % shape.data[i]) * strides.data[i];
      j1 /= shape.data[i];
    }
    out[i1] = a[gid];
  }
  /// END YOUR SOLUTION
}

void EwiseSetitem(const CudaArray& a, CudaArray* out, std::vector<uint32_t> shape,
                  std::vector<uint32_t> strides, size_t offset) {
  /**
   * Set items in a (non-compact) array using CUDA.  You will most likely want to implement a
   * EwiseSetitemKernel() function, similar to those above, that will do the actual work.
   * 
   * Args:
   *   a: _compact_ array whose items will be written to out
   *   out: non-compact array whose items are to be written
   *   shape: shapes of each dimension for a and out
   *   strides: strides of the *out* array (not a, which has compact strides)
   *   offset: offset of the *out* array (not a, which has zero offset, being compact)
   */
  /// BEGIN YOUR SOLUTION
  CudaDims dim = CudaOneDim(a.size);
  EwiseSetitemKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, a.size, VecToCuda(shape),
                                         VecToCuda(strides), offset);

  /// END YOUR SOLUTION
}


__global__ void ScalarSetitemKernel(const scalar_t val, scalar_t* out, size_t size, CudaVec shape,
                              CudaVec strides, size_t offset) {
  /**
   * The CUDA kernel for the ScalarSetitem operation.  Each thread writes val to one item (the
   * gid-th element of the view, in row-major order) of the non-compact array out.
   * 
   * Args:
   *   val: scalar value to write
   *   out: CUDA pointer to out array
   *   size: number of elements to write (product of shape)
   *   shape: vector of shapes of out array (of type CudaVec, for fast passing to CUDA kernel)
   *   strides: vector of strides of out array
   *   offset: offset of out array
   */
  // index of the element within the view
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;

  /// BEGIN YOUR SOLUTION
  if (gid < size) {
    int n = shape.size;
    int i1 = offset;
    int j1 = gid;
    for (int i = n - 1; i >= 0; i--) {
      i1 += (j1 % shape.data[i]) * strides.data[i];
      j1 /= shape.data[i];
    }
    out[i1] = val;
  }
  /// END YOUR SOLUTION
}

void ScalarSetitem(size_t size, scalar_t val, CudaArray* out, std::vector<uint32_t> shape,
                   std::vector<uint32_t> strides, size_t offset) {
  /**
   * Set items in a (non-compact) array
   * 
   * Args:
   *   size: number of elements to write in out array (note that this will not be the same as
   *         out.size, because out is a non-compact subset array);  it _will_ be the same as the 
   *         product of items in shape, but convenient to just pass it here.
   *   val: scalar value to write to
   *   out: non-compact array whose items are to be written
   *   shape: shapes of each dimension of out
   *   strides: strides of the out array
   *   offset: offset of the out array
   */
  /// BEGIN YOUR SOLUTION
  // one thread per element of the view (size), not per element of the whole buffer (out->size)
  CudaDims dim = CudaOneDim(size);
  ScalarSetitemKernel<<<dim.grid, dim.block>>>(val, out->ptr, size, VecToCuda(shape),
                                               VecToCuda(strides), offset);
  /// END YOUR SOLUTION
}
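Unlike the CPU loop, each CUDA thread must recover its multi-index from `gid` alone, by repeatedly taking `% shape[d]` and `/ shape[d]` from the last dimension backwards. The Python sketch below (the helper `kernel_index` is hypothetical) emulates that per-thread arithmetic for every `gid < size` and checks that it visits the same positions, in the same order, as a row-major walk over the view:

```python
from itertools import product


def kernel_index(gid, shape, strides, offset):
    """Mirror of the loop in CompactKernel: peel off the last dimension's index with % and /."""
    pos = offset
    for d in range(len(shape) - 1, -1, -1):
        pos += (gid % shape[d]) * strides[d]
        gid //= shape[d]
    return pos


shape, strides, offset = (3, 2), (4, 2), 1  # columns 1 and 3 of a compact 3x4 array
size = 3 * 2
# every "thread" gid < size computes its position independently of the others
positions = [kernel_index(gid, shape, strides, offset) for gid in range(size)]
# itertools.product iterates in row-major order, so it gives the reference order
reference = [offset + sum(i * s for i, s in zip(idx, strides))
             for idx in product(*(range(d) for d in shape))]
print(positions, positions == reference)  # [1, 3, 5, 7, 9, 11] True
```

The same mapping serves Compact (read from `pos`), EwiseSetitem (write `a[gid]` to `pos`) and ScalarSetitem (write `val` to `pos`), matching the three kernels above.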

Part 7: CUDA Backend - Elementwise and scalar operations

/**
 * In the code that follows, use the above template to create analogous element-wise
 * and scalar operators for the following functions.  See the numpy backend for
 * examples of how they should work.
 *   - EwiseMul, ScalarMul
 *   - EwiseDiv, ScalarDiv
 *   - ScalarPower
 *   - EwiseMaximum, ScalarMaximum
 *   - EwiseEq, ScalarEq
 *   - EwiseGe, ScalarGe
 *   - EwiseLog
 *   - EwiseExp
 *   - EwiseTanh
 *
 * If you implement all these naively, there will be a lot of repeated code, so
 * you are welcome (but not required), to use macros or templates to define these
 * functions (however you want to do so, as long as the functions match the proper
 * signatures above).
 */

/// BEGIN YOUR SOLUTION
// EwiseMul, ScalarMul
__global__ void EwiseMulKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = a[gid] * b[gid];
  }
}

void EwiseMul(const CudaArray& a, const CudaArray& b, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  EwiseMulKernel<<<dim.grid, dim.block>>>(a.ptr, b.ptr, out->ptr, out->size);
}

__global__ void ScalarMulKernel(const scalar_t* a, scalar_t val, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = a[gid] * val;
  }
}

void ScalarMul(const CudaArray& a, scalar_t val, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  ScalarMulKernel<<<dim.grid, dim.block>>>(a.ptr, val, out->ptr, out->size);
}

// EwiseDiv, ScalarDiv
__global__ void EwiseDivKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = a[gid] / b[gid];
  }
}

void EwiseDiv(const CudaArray& a, const CudaArray& b, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  EwiseDivKernel<<<dim.grid, dim.block>>>(a.ptr, b.ptr, out->ptr, out->size);
}

__global__ void ScalarDivKernel(const scalar_t* a, scalar_t val, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = a[gid] / val;
  }
}

void ScalarDiv(const CudaArray& a, scalar_t val, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  ScalarDivKernel<<<dim.grid, dim.block>>>(a.ptr, val, out->ptr, out->size);
}

// ScalarPower
__global__ void ScalarPowerKernel(const scalar_t* a, scalar_t val, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = pow(a[gid], val);
  }
}

void ScalarPower(const CudaArray& a, scalar_t val, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  ScalarPowerKernel<<<dim.grid, dim.block>>>(a.ptr, val, out->ptr, out->size);
}

// EwiseMaximum, ScalarMaximum
__global__ void EwiseMaximumKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    if (a[gid] < b[gid]) {
      out[gid] = b[gid];
    } else {
      out[gid] = a[gid];
    }
  }
}

void EwiseMaximum(const CudaArray& a, const CudaArray& b, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  EwiseMaximumKernel<<<dim.grid, dim.block>>>(a.ptr, b.ptr, out->ptr, out->size);
}

__global__ void ScalarMaximumKernel(const scalar_t* a, scalar_t val, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    if (a[gid] < val) {
      out[gid] = val;
    } else {
      out[gid] = a[gid];
    }
  }
}

void ScalarMaximum(const CudaArray& a, scalar_t val, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  ScalarMaximumKernel<<<dim.grid, dim.block>>>(a.ptr, val, out->ptr, out->size);
}

// EwiseEq, ScalarEq
__global__ void EwiseEqKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = scalar_t(a[gid] - b[gid] < 1e-6 && a[gid] - b[gid] > -1e-6);
  }
}

void EwiseEq(const CudaArray& a, const CudaArray& b, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  EwiseEqKernel<<<dim.grid, dim.block>>>(a.ptr, b.ptr, out->ptr, out->size);
}

__global__ void ScalarEqKernel(const scalar_t* a, scalar_t val, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = scalar_t(a[gid] - val < 1e-6 && a[gid] - val > -1e-6);
  }
}

void ScalarEq(const CudaArray& a, scalar_t val, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  ScalarEqKernel<<<dim.grid, dim.block>>>(a.ptr, val, out->ptr, out->size);
}

// EwiseGe, ScalarGe
__global__ void EwiseGeKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = a[gid] >= b[gid];
  }
}

void EwiseGe(const CudaArray& a, const CudaArray& b, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  EwiseGeKernel<<<dim.grid, dim.block>>>(a.ptr, b.ptr, out->ptr, out->size);
}

__global__ void ScalarGeKernel(const scalar_t* a, scalar_t val, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = a[gid] >= val;
  }
}

void ScalarGe(const CudaArray& a, scalar_t val, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  ScalarGeKernel<<<dim.grid, dim.block>>>(a.ptr, val, out->ptr, out->size);
}

// EwiseLog
__global__ void SingleEwiseLogKernel(const scalar_t* a, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = log(a[gid]);
  }
}

void EwiseLog(const CudaArray& a, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  SingleEwiseLogKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, out->size);
}

// EwiseExp
__global__ void SingleEwiseExpKernel(const scalar_t* a, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = exp(a[gid]);
  }
}

void EwiseExp(const CudaArray& a, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  SingleEwiseExpKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, out->size);
}

// EwiseTanh
__global__ void SingleEwiseTanhKernel(const scalar_t* a, scalar_t* out, size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    out[gid] = tanh(a[gid]);
  }
}

void EwiseTanh(const CudaArray& a, CudaArray* out) {
  CudaDims dim = CudaOneDim(out->size);
  SingleEwiseTanhKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, out->size);
}
/// END YOUR SOLUTION
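The unary kernels above (log, exp, tanh) differ only in the function applied to each element, and they all use the same one-thread-per-element launch with a `gid < size` bounds check. Below is a minimal sketch of factoring that pattern out. It is written as host-side C++ that emulates the CUDA grid, so it runs without a GPU: `LaunchEmulated` stands in for `<<<grid, block>>>` by calling the body once per (block, thread) pair, and the lambda stands in for the kernel body. The names `LaunchEmulated` and `EwiseMapEmulated`, and the block size of 256, are hypothetical and not part of the assignment's API. In real CUDA code the same idea is usually written as a `__global__` kernel templated on a functor type whose `operator()` is marked `__device__`.

```cpp
#include <cassert>
#include <cmath>    // for std::exp / std::tanh ops passed in by callers
#include <cstddef>
#include <vector>

using scalar_t = float;

// Hypothetical host-side stand-in for kernel<<<grid, block>>>(...):
// calls body(gid) once for every thread index the launch would create.
template <typename Body>
void LaunchEmulated(size_t size, size_t block_size, Body body) {
  assert(block_size > 0);
  size_t grid = (size + block_size - 1) / block_size;  // ceil(size / block_size) blocks
  for (size_t block_idx = 0; block_idx < grid; block_idx++) {
    for (size_t thread_idx = 0; thread_idx < block_size; thread_idx++) {
      body(block_idx * block_size + thread_idx);  // gid = blockIdx * blockDim + threadIdx
    }
  }
}

// Generic unary elementwise op: one "thread" per output element,
// with the same bounds check as the kernels above.
template <typename Op>
void EwiseMapEmulated(const scalar_t* a, scalar_t* out, size_t size, Op op) {
  LaunchEmulated(size, 256, [&](size_t gid) {
    if (gid < size) {  // the last block may contain threads past the end of the array
      out[gid] = op(a[gid]);
    }
  });
}
```

For example, `EwiseMapEmulated(a.data(), out.data(), a.size(), [](scalar_t x) { return std::exp(x); })` mirrors `EwiseExp`; swapping the lambda gives log or tanh without another kernel.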

Part 8: CUDA Backend - Reductions

////////////////////////////////////////////////////////////////////////////////
// Max and sum reductions
////////////////////////////////////////////////////////////////////////////////
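__global__ void ReduceMaxKernel(const scalar_t* a, scalar_t* out, size_t size, size_t reduce_size) {
  // One thread per output element: thread gid scans the contiguous block
  // a[gid * reduce_size, (gid + 1) * reduce_size).
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    size_t base = gid * reduce_size;  // size_t avoids int overflow on large arrays
    scalar_t max_val = a[base];       // keep the running max in a register
    for (size_t k = 1; k < reduce_size; k++) {
      if (a[base + k] > max_val) {
        max_val = a[base + k];
      }
    }
    out[gid] = max_val;  // single write to global memory
  }
}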

void ReduceMax(const CudaArray& a, CudaArray* out, size_t reduce_size) {
  /**
   * Reduce by taking maximum over `reduce_size` contiguous blocks.  Even though it is inefficient,
   * for simplicity you can perform each reduction in a single CUDA thread.
   * 
   * Args:
   *   a: compact array of size a.size = out.size * reduce_size to reduce over
   *   out: compact array to write into
   *   reduce_size: size of the dimension to reduce over
   */
  /// BEGIN YOUR SOLUTION
  CudaDims dim = CudaOneDim(out->size);
  ReduceMaxKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, out->size, reduce_size);
  /// END YOUR SOLUTION
}


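__global__ void ReduceSumKernel(const scalar_t* a, scalar_t* out, size_t size, size_t reduce_size) {
  // One thread per output element: thread gid sums the contiguous block
  // a[gid * reduce_size, (gid + 1) * reduce_size).
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < size) {
    size_t base = gid * reduce_size;  // size_t avoids int overflow on large arrays
    scalar_t sum = 0;                 // accumulate in a register
    for (size_t k = 0; k < reduce_size; k++) {
      sum += a[base + k];
    }
    out[gid] = sum;  // single write to global memory
  }
}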

void ReduceSum(const CudaArray& a, CudaArray* out, size_t reduce_size) {
  /**
   * Reduce by taking summation over `reduce_size` contiguous blocks.  Again, for simplicity you 
   * can perform each reduction in a single CUDA thread.
   * 
   * Args:
   *   a: compact array of size a.size = out.size * reduce_size to reduce over
   *   out: compact array to write into
   *   reduce_size: size of the dimension to reduce over
   */
  /// BEGIN YOUR SOLUTION
  CudaDims dim = CudaOneDim(out->size);
  ReduceSumKernel<<<dim.grid, dim.block>>>(a.ptr, out->ptr, out->size, reduce_size);
  /// END YOUR SOLUTION
}
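Both reduction kernels rely on the layout stated in their docstrings: `a` is compact and holds `out.size` consecutive blocks of `reduce_size` elements, so thread `gid` owns `a[gid * reduce_size]` through `a[(gid + 1) * reduce_size - 1]`. The host-side C++ sketch below emulates that one-thread-per-output scheme with a plain loop over `gid`, which makes the indexing easy to check on a small input such as a 2 × 3 array reduced over its last axis. `ReduceEmulated` is a hypothetical helper, not part of the assignment. It assumes `reduce_size >= 1` and that the caller has already permuted the reduced axis to the end and compacted the array, which is what makes each block contiguous.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using scalar_t = float;

// Emulates a one-thread-per-output reduction: "thread" gid combines the
// contiguous block a[gid * reduce_size, (gid + 1) * reduce_size).
template <typename Combine>
std::vector<scalar_t> ReduceEmulated(const std::vector<scalar_t>& a, size_t reduce_size,
                                     Combine combine) {
  assert(reduce_size >= 1 && a.size() % reduce_size == 0);
  size_t out_size = a.size() / reduce_size;  // a.size == out.size * reduce_size
  std::vector<scalar_t> out(out_size);
  for (size_t gid = 0; gid < out_size; gid++) {  // each iteration plays one CUDA thread
    size_t base = gid * reduce_size;
    scalar_t acc = a[base];  // running result kept in a local variable
    for (size_t k = 1; k < reduce_size; k++) {
      acc = combine(acc, a[base + k]);
    }
    out[gid] = acc;
  }
  return out;
}
```

With `a = [[1, 5, 3], [4, 2, 6]]` stored row-major and `reduce_size = 3`, an addition combiner gives `{9, 12}` and a max combiner gives `{5, 6}`, matching `ReduceSum` and `ReduceMax` over the last axis.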

Part 9: CUDA Backend - Matrix multiplication

////////////////////////////////////////////////////////////////////////////////
// Matrix multiplication
////////////////////////////////////////////////////////////////////////////////

__global__ void MatmulKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, uint32_t M, uint32_t N,
            uint32_t P) {
  // One thread per output row: thread gid computes out[gid, 0..P).
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < M) {
    for (size_t j = 0; j < P; j++) {
      scalar_t acc = 0;  // accumulate in a register, write to global memory once
      for (size_t k = 0; k < N; k++) {
        acc += a[gid * N + k] * b[k * P + j];  // a[gid][k] * b[k][j]
      }
      out[gid * P + j] = acc;  // out[gid][j]
    }
  }
}

void Matmul(const CudaArray& a, const CudaArray& b, CudaArray* out, uint32_t M, uint32_t N,
            uint32_t P) {
  /**
   * Multiply two (compact) matrices into an output (also compact) matrix.  You will want to look
   * at the lecture and notes on GPU-based linear algebra to see how to do this.  Since ultimately
   * mugrade is just evaluating correctness, you _can_ implement a version that simply parallelizes
   * over (i,j) entries in the output array.  However, to really get the full benefit of this
   * problem, we would encourage you to use cooperative fetching, shared memory and register tiling,
   * and other ideas covered in the class notes.  Note that unlike the tiled matmul function in
   * the CPU backend, here you should implement a single function that works across all size
   * matrices, whether or not they are a multiple of a tile size.  As with previous CUDA
   * implementations, this function here will largely just set up the kernel call, and you should
   * implement the logic in a separate MatmulKernel() call.
   * 
   *
   * Args:
   *   a: compact 2D array of size m x n
   *   b: compact 2D array of size n x p
   *   out: compact 2D array of size m x p to write the output to
   *   M: rows of a / out
   *   N: columns of a / rows of b
   *   P: columns of b / out
   */

  /// BEGIN YOUR SOLUTION
  CudaDims dim = CudaOneDim(M);
  MatmulKernel<<<dim.grid, dim.block>>>(a.ptr, b.ptr, out->ptr, M, N, P);
  /// END YOUR SOLUTION
}
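The kernel above parallelizes only over the M rows, so when M is small a handful of threads do all of the work no matter how large N and P are. The docstring's simpler alternative is one thread per output entry: launch over `M * P` threads (for example with `CudaOneDim(M * P)` and a `gid < M * P` bounds check in the kernel) and let thread `gid` compute `out[i][j]` with `i = gid / P` and `j = gid % P`. The sketch below emulates that mapping on the host in plain C++, so the index arithmetic can be checked without a GPU. `MatmulPerEntryEmulated` is a hypothetical name; this is only the (i, j) parallelization the docstring allows, not the tiled, shared-memory version it encourages.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using scalar_t = float;

// Host-side emulation of a one-thread-per-output-entry matmul kernel.
// a is M x N, b is N x P, out is M x P; all compact and row-major.
void MatmulPerEntryEmulated(const std::vector<scalar_t>& a, const std::vector<scalar_t>& b,
                            std::vector<scalar_t>& out, uint32_t M, uint32_t N, uint32_t P) {
  assert(a.size() == size_t(M) * N && b.size() == size_t(N) * P);
  size_t total = size_t(M) * P;  // number of in-bounds threads for a launch over M * P entries
  out.assign(total, scalar_t(0));
  for (size_t gid = 0; gid < total; gid++) {  // each iteration plays one CUDA thread
    size_t i = gid / P;  // output row
    size_t j = gid % P;  // output column
    scalar_t acc = 0;
    for (size_t k = 0; k < N; k++) {
      acc += a[i * N + k] * b[k * P + j];
    }
    out[gid] = acc;  // gid == i * P + j
  }
}
```

Neighbouring threads share `i` and read consecutive `j`, so their reads of `b` and writes to `out` touch adjacent addresses; that access pattern is one reason this mapping is usually preferred over one thread per row.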