import { broadcast_shapes, asarray, tester, array, empty, broadcast_to, get_size, NDArray } from './core.mjs';
/**
 * Build a `Broadcast` iterator over any number of inputs.
 *
 * Each argument is passed through `asarray` first (presumably so that plain
 * nested JS arrays are accepted as well as NDArrays — confirm against
 * `asarray`). The converted list is then handed to the `Broadcast`
 * constructor.
 *
 * @param {...NDArray} arrays - inputs to broadcast against one another.
 * @returns {Broadcast} iterator yielding one element per input at each
 *   position of the broadcast shape.
 */
export function broadcast(...arrays) {
  const converted = [];
  for (const item of arrays) {
    converted.push(asarray(item));
  }
  return new Broadcast(converted);
}
/**
 * Iterator over several arrays broadcast against each other, mirroring
 * `numpy.broadcast`. Each step yields an array holding one element from every
 * input, all taken at the same position of the common broadcast shape.
 * @class
 */
export class Broadcast {
  /**
   * @param {NDArray[]} arrays - inputs expected to be mutually broadcastable
   *   (presumably `broadcast_shapes` rejects incompatible shapes — confirm).
   */
  constructor(arrays) {
    /** @member {number[]} Common shape every input is broadcast to. */
    this.shape = broadcast_shapes(...arrays.map(array => array.shape));
    /** @member {NDArray[]} The inputs, each broadcast to `this.shape`. */
    this.arrays = arrays.map(array => broadcast_to(array, this.shape));
    /** @member {number} Number of dimensions of `this.shape`. */
    this.ndim = this.shape.length;
    /** @member {number} Iteration length: `next()` reports done once `index` reaches it. */
    this.size = get_size(this.shape);
    this.reset();
  }

  /**
   * Makes the object usable with `for...of` and spread. If it has already
   * been (partially or fully) consumed, it is rewound first, so every new
   * loop starts from the beginning.
   * @returns {Broadcast}
   */
  [Symbol.iterator]() {
    if (this.index !== 0) this.reset();
    return this;
  }

  /**
   * @typedef {Object} BroadcastResult
   * @property {any} value - one element per input array, or `undefined` when done
   * @property {boolean} done
   */

  /**
   * Advance one position in the broadcast shape.
   * @returns {BroadcastResult}
   */
  next() {
    // Once exhausted, stay exhausted: do not advance the per-array iterators
    // or push `index` past `size`, so `index` ends equal to `size` (as in NumPy).
    if (this.index >= this.size) return { value: undefined, done: true };
    const value = this.iters.map(iter => iter.next().value);
    this.index++;
    return { value, done: false };
  }

  /**
   * Rewind to the first position by taking fresh `flat` iterators of every
   * broadcast array.
   */
  reset() {
    /** @member {Flatiter[]} One flat iterator per broadcast array. */
    this.iters = this.arrays.map(array => array.flat);
    /** @member {number} Number of positions yielded so far. */
    this.index = 0;
  }
}
// Self-tests for `broadcast`, registered only outside production builds.
// Each case is (function under test, actual-value thunk, expected-value thunk);
// the leading template literal in each thunk is the NumPy doctest it mirrors.
if (!process.env.PRODUCTION) {
  tester
    .add(
      broadcast,
      () => {
        `
out = np.empty(b.shape)
out.flat = [u+v for (u,v) in b]
out
array([[5., 6., 7.],
[6., 7., 8.],
[7., 8., 9.]])
`;
        const x = array([[1], [2], [3]]);
        const y = array([4, 5, 6]);
        const b = broadcast(x, y);
        const out = empty(b.shape);
        // Sum each broadcast (u, v) pair into a flat list, then assign it.
        out.flat = Array.from(b, ([u, v]) => u + v);
        return out;
      },
      () =>
        array([
          [5, 6, 7],
          [6, 7, 8],
          [7, 8, 9],
        ])
    )
    .add(
      broadcast,
      () => {
        `
>>> x = np.array([1, 2, 3])
>>> y = np.array([[4], [5], [6]])
>>> b = np.broadcast(x, y)
>>> b.index
0
>>> next(b), next(b), next(b)
((1, 4), (2, 4), (3, 4))
>>> b.index
3
>>> b.reset()
>>> b.index
0
`;
        const x = array([1, 2, 3]);
        const y = array([[4], [5], [6]]);
        const b = broadcast(x, y);
        const snapshots = [b.index];
        snapshots.push(b.next().value, b.next().value, b.next().value);
        snapshots.push(b.index);
        b.reset();
        snapshots.push(b.index);
        return snapshots;
      },
      () => [0, [1, 4], [2, 4], [3, 4], 3, 0]
    );
}