import {
tester,
arange,
array,
normalize_axis_index,
asarray,
array_split,
ndim,
NDArray,
empty,
} from './core.mjs';
/**
 * Split an array into sub-arrays along one axis, insisting on an equal
 * division whenever a section count (rather than a list of indices) is given.
 *
 * @param {NDArray} ary - Array to split; it is passed through `asarray` first.
 * @param {number|number[]|NDArray} indices_or_sections - Either a section
 *   count, or the indices at which to cut. A value exposing a `shape` is
 *   unwrapped with its `.array()` method before use.
 * @param {number} [axis=0] - Axis to split along; normalized with
 *   `normalize_axis_index` against the array's `ndim`.
 * @returns {NDArray[]} The pieces, exactly as produced by `array_split`.
 * @throws {string} When a section count does not evenly divide the length of
 *   the chosen axis.
 */
export function split(ary, indices_or_sections, axis = 0) {
  const source = asarray(ary);
  const splitAxis = normalize_axis_index(axis, source.ndim);

  // Anything carrying a `shape` is treated as an array object and unwrapped.
  const spec =
    indices_or_sections.shape != null ? indices_or_sections.array() : indices_or_sections;

  // No `length` means `spec` is a section count, not a list of cut indices.
  const isSectionCount = spec.length == undefined;
  if (isSectionCount && source.shape[splitAxis] % spec) {
    throw `array split does not result in an equal division`;
  }

  return array_split(source, spec, splitAxis);
}
/**
 * Split an array along its third axis (axis index 2) by delegating to `split`.
 *
 * @param {*} ary - Input whose dimensionality is measured with `ndim`.
 * @param {*} indices_or_sections - Forwarded unchanged to `split`.
 * @returns {*} Whatever `split` returns for axis 2.
 * @throws {string} When the input has fewer than 3 dimensions.
 */
export function dsplit(ary, indices_or_sections) {
  const rank = ndim(ary);
  if (rank < 3) {
    throw `dsplit only works on arrays of 3 or more dimensions`;
  }
  const depthAxis = 2;
  return split(ary, indices_or_sections, depthAxis);
}
/**
 * Split an array by delegating to `split`: along axis 1 when the input has
 * more than one dimension, otherwise along axis 0.
 *
 * @param {*} ary - Input whose dimensionality is measured with `ndim`.
 * @param {*} indices_or_sections - Forwarded unchanged to `split`.
 * @returns {*} Whatever `split` returns for the chosen axis.
 * @throws {string} When the input has 0 dimensions.
 */
export function hsplit(ary, indices_or_sections) {
  const rank = ndim(ary);
  if (rank == 0) {
    throw `hsplit only works on arrays of 1 or more dimensions`;
  }
  // One-dimensional input has no axis 1, so it is split along axis 0 instead.
  const axis = rank > 1 ? 1 : 0;
  return split(ary, indices_or_sections, axis);
}
/**
 * Split an array along its first axis (axis index 0) by delegating to `split`.
 *
 * @param {*} ary - Input whose dimensionality is measured with `ndim`.
 * @param {*} indices_or_sections - Forwarded unchanged to `split`.
 * @returns {*} Whatever `split` returns for axis 0.
 * @throws {string} When the input has fewer than 2 dimensions.
 */
export function vsplit(ary, indices_or_sections) {
  const rank = ndim(ary);
  if (rank < 2) {
    throw `vsplit only works on arrays of 2 or more dimensions`;
  }
  const rowAxis = 0;
  return split(ary, indices_or_sections, rowAxis);
}
// Self-checks for `split`; registered only when PRODUCTION is unset or falsy.
if (!process.env.PRODUCTION) {
  tester
    // A section count: nine elements cut into three equal pieces.
    .add(
      split,
      () => split(arange(9.0), 3),
      () => [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]
    )
    // Explicit cut indices; the index past the end (10) yields an empty last piece.
    .add(
      split,
      () => split(arange(8.0), [3, 5, 6, 10]),
      () => [array([0, 1, 2]), array([3, 4]), array([5]), array([6, 7]), array([])]
    );
}
// Self-checks for `dsplit`; registered only when PRODUCTION is unset or falsy.
if (!process.env.PRODUCTION) {
  tester
    // A section count: the length-4 third axis is cut into two halves.
    .add(
      dsplit,
      () => dsplit(arange(16.0).reshape(2, 2, 4), 2),
      () => [
        array([[[0, 1], [4, 5]], [[8, 9], [12, 13]]]),
        array([[[2, 3], [6, 7]], [[10, 11], [14, 15]]]),
      ]
    )
    // Cut indices supplied as an array object; the expected last piece is empty.
    .add(
      dsplit,
      () => dsplit(arange(16.0).reshape(2, 2, 4), array([3, 6])),
      () => [
        array([[[0, 1, 2], [4, 5, 6]], [[8, 9, 10], [12, 13, 14]]]),
        array([[[3], [7]], [[11], [15]]]),
        empty([2, 2, 0]),
      ]
    );
}
// Self-checks for `hsplit`; registered only when PRODUCTION is unset or falsy.
if (!process.env.PRODUCTION) {
  tester
    // A section count on a 4x4 input: columns are cut into two halves.
    .add(
      hsplit,
      () => hsplit(arange(16.0).reshape(4, 4), 2),
      () => [
        array([[0, 1], [4, 5], [8, 9], [12, 13]]),
        array([[2, 3], [6, 7], [10, 11], [14, 15]]),
      ]
    )
    // Cut indices supplied as an array object; the expected last piece is empty.
    .add(
      hsplit,
      () => hsplit(arange(16.0).reshape(4, 4), array([3, 6])),
      () => [
        array([[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]]),
        array([[3], [7], [11], [15]]),
        empty([4, 0]),
      ]
    )
    // Three-dimensional input is still split along axis 1.
    .add(
      hsplit,
      () => hsplit(arange(8.0).reshape(2, 2, 2), 2),
      () => [array([[[0, 1]], [[4, 5]]]), array([[[2, 3]], [[6, 7]]])]
    )
    // One-dimensional input falls back to axis 0.
    .add(
      hsplit,
      () => hsplit(array([0, 1, 2, 3, 4, 5]), 2),
      () => [array([0, 1, 2]), array([3, 4, 5])]
    );
}
// Self-checks for `vsplit`; registered only when PRODUCTION is unset or falsy.
if (!process.env.PRODUCTION) {
  tester
    // A section count on a 4x4 input: rows are cut into two halves.
    .add(
      vsplit,
      () => vsplit(arange(16.0).reshape(4, 4), 2),
      () => [
        array([[0, 1, 2, 3], [4, 5, 6, 7]]),
        array([[8, 9, 10, 11], [12, 13, 14, 15]]),
      ]
    )
    // Cut indices supplied as an array object; the expected last piece is empty.
    .add(
      vsplit,
      () => vsplit(arange(16.0).reshape(4, 4), array([3, 6])),
      () => [
        array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
        array([[12, 13, 14, 15]]),
        empty([0, 4]),
      ]
    )
    // Three-dimensional input is split along its first axis.
    .add(
      vsplit,
      () => vsplit(arange(8.0).reshape(2, 2, 2), 2),
      () => [array([[[0, 1], [2, 3]]]), array([[[4, 5], [6, 7]]])]
    );
}