Source: dot.mjs

import {
	tester,
	arange,
	array,
	asarray,
	ones,
	zeros,
	slice,
	NDArray,
	amax,
	multiply,
	empty,
	shallow_array_equal,
	array_equal,
	ndoffset,
	sum,
	matmul,
	ndindex,
} from './core.mjs';

/**
 * Minimal runtime assertion helper.
 *
 * Does nothing when `cond` is truthy; otherwise throws an `Error` whose
 * message is `msg`.
 *
 * @param {*} cond - value checked for truthiness
 * @param {string} [msg] - message attached to the thrown error
 * @throws {Error} when `cond` is falsy
 */
export function assert(cond, msg) {
	if (cond) return;
	throw new Error(msg);
}

/**
 * @param {NDArray} a
 * @param {NDArray} b
 * @param {null|NDArray} [out]
 * @returns {NDArray}
 */
/**
 * Dot product of two arrays, modelled on NumPy's `dot` for the cases
 * implemented here:
 *
 * - 1-D · 1-D: inner product (sum of element-wise products).
 * - 2-D · 2-D: matrix product, delegated to `matmul`.
 * - either operand 0-D: element-wise `multiply`.
 * - N-D · 1-D: sum product over the last axis of `a` and the only axis of `b`,
 *   giving an array of shape `a.shape.slice(0, -1)`.
 *
 * Any other combination of ranks (e.g. 1-D · 2-D, or N-D · M-D with M >= 2)
 * is not implemented and throws instead of silently returning `undefined`.
 *
 * @param {NDArray} a - first operand (anything accepted by `asarray`)
 * @param {NDArray} b - second operand (anything accepted by `asarray`)
 * @param {null|NDArray} [out] - optional destination; in the N-D · 1-D case
 *   its shape must equal `a.shape.slice(0, -1)`
 * @returns {NDArray}
 * @throws {Error} when the contracted axes differ in length, when `out` has
 *   the wrong shape, or when the rank combination is unsupported
 */
export function dot(a, b, out = null) {
	a = asarray(a);
	b = asarray(b);

	if (a.ndim === 1 && b.ndim === 1) {
		// multiply() may broadcast a length-1 operand, which would produce a
		// wrong inner product instead of an error, so check lengths explicitly.
		assert(a.shape[0] === b.shape[0], `shapes ${a.shape} and ${b.shape} not aligned`);
		return sum(multiply(a, b), 0, out);
	}

	if (a.ndim === 2 && b.ndim === 2) return matmul(a, b, out);

	if (a.ndim === 0 || b.ndim === 0) return multiply(a, b, out);

	if (b.ndim === 1) {
		assert(a.shape.at(-1) === b.shape[0], `shapes ${a.shape} and ${b.shape} not aligned`);

		const shape = a.shape.slice(0, -1);
		if (out == null) out = empty(shape);
		else assert(shallow_array_equal(shape, out.shape), `out shape mismatch`);

		// Each leading index selects one 1-D row of `a`; the recursive call
		// takes the 1-D · 1-D path and writes into the matching element of `out`.
		for (const index of ndindex(shape)) {
			dot(a.get(index), b, out.at(slice('...'), ...index));
		}

		return out;
	}

	throw new Error(`dot not implemented for ${a.ndim}-D and ${b.ndim}-D operands`);
}

// Self-tests for `dot`, registered with the shared `tester` unless
// process.env.PRODUCTION is set. In each `.add` call the second argument
// invokes `dot` and the third builds the value it is presumably compared
// against (see `tester` in core.mjs).
if (!process.env.PRODUCTION) {
	tester
		// Scalar (0-D) operands.
		.add(
			dot,
			() => dot(3, 4),
			() => 12
		)
		// 2-D matrix times 1-D vector.
		.add(
			dot,
			() =>
				dot(
					[
						[1, 2],
						[3, 4],
					],
					[4, 1]
				),
			() => array([6, 16])
		)
		// Identity matrix times a 2x2 matrix.
		.add(
			dot,
			() =>
				dot(
					[
						[1, 0],
						[0, 1],
					],
					[
						[4, 1],
						[2, 2],
					]
				),
			() =>
				array([
					[4, 1],
					[2, 2],
				])
		);
}