Source: diagonal.mjs

import { tester, arange, array, asarray, slice, NDArray, normalize_axis_index, transpose } from './core.mjs';

/**
 * @param {NDArray} a
 * @param {number} [offset]
 * @param {number} [axis1]
 * @param {number} [axis2]
 * @returns {NDArray}
 */
/**
 * Return the diagonal(s) of `a` taken over two of its axes, as a strided view
 * (modelled on numpy's `diagonal`).
 *
 * `axis1`/`axis2` are moved to the end (the remaining axes keep their relative
 * order) and the diagonal becomes the last axis of the result. The result is
 * built with `as_strided` on the same data, so writing to it writes into `a`.
 *
 * @param {NDArray} a - input array; passed through `asarray`.
 * @param {number} [offset] - diagonal offset: > 0 selects a diagonal above the
 *   main one (shifted along axis2), < 0 one below it (shifted along axis1).
 * @param {number} [axis1] - first axis of the 2-D sub-arrays (negative allowed).
 * @param {number} [axis2] - second axis of the 2-D sub-arrays (negative allowed).
 * @returns {NDArray} view of shape `[...otherAxes, length]`; `length` is 0 when
 *   `offset` falls outside the sub-arrays.
 * @throws {string} if `a.ndim < 2`, or if `axis1` and `axis2` name the same axis.
 */
export function diagonal(a, offset = 0, axis1 = 0, axis2 = 1) {
	a = asarray(a);
	const { ndim } = a;
	if (ndim < 2) throw `array.ndim must be >= 2`;

	axis1 = normalize_axis_index(axis1, ndim);
	axis2 = normalize_axis_index(axis2, ndim);
	// Without this check the permutation below would list one axis twice and
	// drop another, silently producing a wrong view.
	if (axis1 === axis2) throw `axis1 and axis2 cannot be the same`;

	// Permutation: every other axis in original order, then axis1, then axis2.
	const axes = [];
	for (let i = 0; i < ndim; i++) {
		if (i !== axis1 && i !== axis2) axes.push(i);
	}
	axes.push(axis1, axis2);

	a = transpose(a, axes);

	// https://github.com/numpy/numpy-refactor/blob/6de313865ec3f49bcdd06ccbc879f27e65acf818/numpy/core/src/multiarray/item_selection.c
	// view only
	// writable, no need to d.setflags(write=True)

	const n1 = a.shape[ndim - 2];
	const n2 = a.shape[ndim - 1];
	const s1 = a.strides[ndim - 2];
	const s2 = a.strides[ndim - 1];

	// Work from the real strides instead of assuming a C-contiguous n1 x n2
	// matrix, so transposed or sliced inputs are handled correctly.
	let start;
	let length;
	if (offset >= 0) {
		start = offset * s2; // shift right along axis2
		length = Math.min(n1, n2 - offset);
	} else {
		start = -offset * s1; // shift down along axis1
		length = Math.min(n1 + offset, n2);
	}
	// An offset beyond the matrix gives an empty diagonal, never a negative length.
	length = Math.max(0, length);

	return a.as_strided(
		[...a.shape.slice(0, -2), length],
		[...a.strides.slice(0, -2), s1 + s2],
		// For an empty diagonal keep the base offset instead of pointing past the data.
		a.offset + (length > 0 ? start : 0)
	);
}

// Self-test registrations, skipped when PRODUCTION is set. Each entry pairs a
// thunk that calls `diagonal` with a thunk that builds the expected array;
// entries are registered with `tester` in the order listed.
if (!process.env.PRODUCTION) {
	const cases = [
		// main diagonal of a 2x2 matrix
		[() => diagonal(arange(4).reshape(2, 2)), () => array([0, 3])],
		// first super-diagonal of a 2x2 matrix
		[() => diagonal(arange(4).reshape(2, 2), 1), () => array([1])],
		// 3-D input: diagonal over axes 0 and 1 ends up as the last axis
		[
			() => diagonal(arange(8).reshape(2, 2, 2), 0, 0, 1),
			() =>
				array([
					[0, 6],
					[1, 7],
				]),
		],
		// 2-D slices of a 3-D array (non-unit strides, non-zero base offset)
		[() => diagonal(arange(8).reshape(2, 2, 2).at(slice(':'), slice(':'), 0), 0, 0, 1), () => array([0, 6])],
		[() => diagonal(arange(8).reshape(2, 2, 2).at(slice(':'), slice(':'), 1), 0, 0, 1), () => array([1, 7])],
		// the result is a view: writing to it updates the source array
		[
			() => {
				const base = arange(8);
				diagonal(base.reshape(2, 2, 2), 0, 0, 1).set(-1);
				return base;
			},
			() => array([-1, -1, 2, 3, 4, 5, -1, -1]),
		],
	];
	for (const [actual, expected] of cases) {
		tester.add(diagonal, actual, expected);
	}
}