// Catmull-Rom code originally taken from https://www.npmjs.com/package/cat-rom-spline

import { add, distance, scale } from "../../../../backend/src/shared/math-utils";

/**
 * Samples a Catmull-Rom spline that passes through every input point.
 *
 * Knot spacing is `chordLength ** tension` (the exponent usually called alpha):
 * 0 gives the uniform variant, 0.5 the centripetal one and 1 the chordal one.
 * Despite its name, `tension` does not tighten the curve the way a
 * cardinal-spline tension would.
 *
 * The curve is extended past both ends with phantom control points made by
 * reflecting the second (resp. second-to-last) point through the first
 * (resp. last) point, so the output starts and ends exactly on the inputs.
 *
 * @param points Flat coordinate list `[x0, y0, x1, y1, ...]`. A trailing
 *   unpaired value is ignored. Consecutive duplicate points are collapsed,
 *   because a zero-length chord adds nothing to the curve and produces a
 *   degenerate (zero-length) knot interval.
 * @param samplesPerSegment Points emitted per segment, counting the segment's
 *   start point (which is always emitted).
 * @param tension Knot-spacing exponent (see above).
 * @returns Flat coordinate list of the sampled curve. For n >= 2 distinct
 *   points and an integer samplesPerSegment >= 1, it holds
 *   (n - 1) * samplesPerSegment + 1 points. With fewer than two distinct
 *   points, the de-duplicated input is returned unchanged.
 */
export function calculateCatmullRomCurveThroughPoints(
  points: number[],
  samplesPerSegment = 1000,
  tension = 0.5
): number[] {
  // Collapse consecutive duplicates (and drop any unpaired trailing value).
  const distinct: number[] = [];
  for (let i = 0; i + 1 < points.length; i += 2) {
    const x = points[i];
    const y = points[i + 1];
    const n = distinct.length;
    if (n >= 2 && distinct[n - 2] === x && distinct[n - 1] === y) {
      continue;
    }
    distinct.push(x, y);
  }

  // A curve needs at least two distinct points; otherwise there is nothing to sample.
  if (distinct.length < 4) {
    return distinct;
  }

  const len = distinct.length;
  const firstVector = [distinct[2] - distinct[0], distinct[3] - distinct[1]];
  const lastVector = [distinct[len - 4] - distinct[len - 2], distinct[len - 3] - distinct[len - 1]];

  // Phantom end points: mirror the neighbouring point through each endpoint.
  const interpolatedFirstPoint = [distinct[0] - firstVector[0], distinct[1] - firstVector[1]];
  const interpolatedLastPoint = [distinct[len - 2] - lastVector[0], distinct[len - 1] - lastVector[1]];

  const controlPoints = [...interpolatedFirstPoint, ...distinct, ...interpolatedLastPoint];

  const results: number[] = [];

  // Slide a window of four control points (eight coordinates) two coordinates at a time.
  for (let i = 0; i + 8 <= controlPoints.length; i += 2) {
    const p0 = [controlPoints[i], controlPoints[i + 1]];
    const p1 = [controlPoints[i + 2], controlPoints[i + 3]];
    const p2 = [controlPoints[i + 4], controlPoints[i + 5]];
    const p3 = [controlPoints[i + 6], controlPoints[i + 7]];

    // Emit p1 exactly, then the interior samples. The sampler's first sample is
    // taken at p1's own knot and reproduces p1 up to rounding, so skipping it
    // avoids a duplicated vertex at every segment start.
    results.push(...p1);
    const samples = catmullRomSampleSegment(p0, p1, p2, p3, samplesPerSegment, tension);
    // Push one sample at a time rather than spreading the whole segment into a
    // single call, which can exceed the engine's argument-count limit.
    for (let s = 1; s < samples.length; s++) {
      results.push(...samples[s]);
    }
  }

  // Segments emit only their start points, so close the curve on the last input point.
  results.push(distinct[len - 2], distinct[len - 1]);

  return results;
}

/**
 * Samples the curve piece between p1 and p2 at evenly spaced knot values.
 *
 * Knots start at 0 for p0, and each following knot adds
 * `Math.pow(chordLength, tension)`. The returned samples cover the half-open
 * knot range [t1, t2): the first one is taken at p1's knot, and p2 itself is
 * left for the next segment.
 *
 * @param p0 Control point before the segment.
 * @param p1 Segment start point.
 * @param p2 Segment end point.
 * @param p3 Control point after the segment.
 * @param samples Number of sample points to produce.
 * @param tension Exponent applied to chord lengths when spacing the knots.
 * @returns The sampled points, in order from p1 toward p2.
 */
function catmullRomSampleSegment(
  p0: number[],
  p1: number[],
  p2: number[],
  p3: number[],
  samples: number,
  tension: number
): number[][] {
  const knotStart = Math.pow(distance(p0, p1), tension);
  const knotEnd = knotStart + Math.pow(distance(p1, p2), tension);
  const knotAfter = knotEnd + Math.pow(distance(p2, p3), tension);
  const stride = (knotEnd - knotStart) / samples;

  const sampled: number[][] = [];
  let k = 0;
  while (k < samples) {
    const knot = knotStart + k * stride;
    sampled.push(catmullRomSamplePoint(p0, p1, p2, p3, 0, knotStart, knotEnd, knotAfter, knot));
    k += 1;
  }
  return sampled;
}

/**
 * Evaluates the Catmull-Rom segment between p1 and p2 at knot value `t`
 * (expected to lie in [t1, t2]), using the Barry–Goldman pyramid of
 * knot-weighted linear interpolations over the knots t0..t3.
 *
 * Zero-length knot intervals (which arise when consecutive control points
 * coincide) are handled by `lerpByKnots`, so the result stays finite
 * instead of becoming NaN from a 0 / 0 division.
 *
 * @returns The interpolated point: p1 at t = t1 and p2 at t = t2 (up to rounding).
 */
function catmullRomSamplePoint(
  p0: number[],
  p1: number[],
  p2: number[],
  p3: number[],
  t0: number,
  t1: number,
  t2: number,
  t3: number,
  t: number
): number[] {
  // First level: interpolate along each chord.
  const a1 = lerpByKnots(p0, p1, t0, t1, t);
  const a2 = lerpByKnots(p1, p2, t1, t2, t);
  const a3 = lerpByKnots(p2, p3, t2, t3, t);

  // Second level: blend neighbouring chord interpolants over wider knot spans.
  const b1 = lerpByKnots(a1, a2, t0, t2, t);
  const b2 = lerpByKnots(a2, a3, t1, t3, t);

  // Final level: blend over the segment's own knot span.
  return lerpByKnots(b1, b2, t1, t2, t);
}

/**
 * Linearly interpolates between `from` (at knot `ta`) and `to` (at knot `tb`)
 * at knot value `t`.
 *
 * When the interval is empty (`ta === tb`), the weights would be 0 / 0. With
 * chord-length-based knots, an empty interval means the two end points coincide
 * (up to rounding), so a copy of `from` is returned instead.
 */
function lerpByKnots(from: number[], to: number[], ta: number, tb: number, t: number): number[] {
  const span = tb - ta;
  if (span === 0) {
    return [...from];
  }
  return add(scale(from, (tb - t) / span), scale(to, (t - ta) / span));
}
