import type { ComponentType, SyntheticEvent } from "react";
import {
  createContext,
  useCallback,
  useContext,
  useMemo,
  useState,
} from "react";
import type { UseFormReturn } from "react-hook-form";
import type { Path } from "react-hook-form";
import { z } from "zod";

/**
 * One page of a multi-step form.
 *
 * @typeParam T - Zod schema type of the step. With the default (`any`) the
 *   conditional on `schema` resolves to `any`, so the schema is loosely typed
 *   unless a concrete schema type is supplied.
 */
interface Step<T = any> {
  /** Step identifier; `createSchemaFromSteps` uses it as the key for this step's schema. */
  name: string;

  /**
   * Optional zod schema for the step's fields. Steps without one are
   * treated as always valid when navigating.
   */
  schema?: T extends z.ZodType ? T : z.ZodType;

  /** UI component for this step (presumably rendered while the step is active; rendering is not shown in this file). */
  component: ComponentType;
}

/** Shape of the value shared through `MultiStepFormContext`. */
type MultiStepFormContextValue = ReturnType<typeof useMultiStepForm>;

/**
 * React context carrying the state returned by `useMultiStepForm`.
 *
 * The default is `null`, which `useMultiStepFormContext` treats as "no
 * provider" and reports with an error.
 */
export const MultiStepFormContext =
  createContext<MultiStepFormContextValue | null>(null);

/**
 * Builds the whole-form zod object schema from a list of steps.
 *
 * Every step that declares a (truthy) schema contributes one key, the step's
 * `name`, whose value is that step's schema. Steps without a schema are left
 * out. The declared return type maps each step name to its schema type, so
 * the resulting form values stay precisely typed.
 */
export function createSchemaFromSteps<T extends Step[] | readonly Step[]>(
  steps: T
) {
  const stepsWithSchema = steps.filter(({ schema }) => schema);
  const shape = Object.fromEntries(
    stepsWithSchema.map(({ name, schema }) => [name, schema])
  );

  return z.object(shape) as z.ZodObject<{
    [K in T[number] as K["name"]]: K["schema"] extends z.ZodType
      ? K["schema"]
      : never;
  }>;
}

/**
 * Reads the multi-step form state provided through `MultiStepFormContext`.
 *
 * @typeParam Schema - Zod schema of the whole form. It is used only to type
 *   the returned value and is not checked at runtime.
 * @returns The value produced by `useMultiStepForm`.
 * @throws Error when no non-null `MultiStepFormContext` value is provided.
 */
export function useMultiStepFormContext<Schema extends z.ZodType>() {
  const context = useContext(MultiStepFormContext);

  // Guard on the raw, nullable context value before narrowing its type, so
  // the check is visible to the compiler instead of hidden behind a cast.
  if (!context) {
    throw new Error(
      "useMultiStepFormContext must be used within a MultiStepForm"
    );
  }

  return context as ReturnType<typeof useMultiStepForm<Schema>>;
}

/**
 * Returns the top-level field names of a step schema.
 *
 * Any number of `z.ZodEffects` wrappers (zod v3 refinements / transforms) are
 * unwrapped first. A schema that does not resolve to a `z.ZodObject` has no
 * named fields, so an empty list is returned instead of throwing.
 */
function getSchemaFieldNames(schema: z.ZodTypeAny): string[] {
  let current: z.ZodTypeAny = schema;

  while (current instanceof z.ZodEffects) {
    current = current.innerType();
  }

  return current instanceof z.ZodObject ? Object.keys(current.shape) : [];
}

/**
 * Drives navigation and validation for a form split into sequential steps.
 *
 * Each step's fields are expected to live in the form values under the
 * step's `name`. This is the layout produced by `createSchemaFromSteps` and
 * assumed by the `${step.name}.${field}` paths below.
 *
 * @param form - react-hook-form instance for the whole multi-step form.
 * @param schema - Zod schema of the whole form. It is not read at runtime
 *   here and only fixes the `Schema` type parameter.
 * @param steps - Ordered step definitions.
 * @param allowInvalidNavigation - When true, `nextStep` and `goToStep` move
 *   on even if the current step's values fail its schema.
 * @returns The form; current-step info; the `nextStep`, `prevStep` and
 *   `goToStep` handlers; `direction`; `isStepValid`; and `invalidSteps`.
 */
export function useMultiStepForm<Schema extends z.ZodType>(
  form: UseFormReturn<z.infer<Schema>>,
  schema: Schema,
  steps: Step[] | readonly Step[],
  allowInvalidNavigation = false
) {
  // Dotted field paths of the whole-form values, e.g. "<stepName>.<field>".
  type FormFieldPath = Path<z.infer<typeof schema>>;

  const [currentStepIndex, setCurrentStepIndex] = useState(0);
  // Stays undefined until the first navigation.
  const [direction, setDirection] = useState<"forward" | "backward">();

  /**
   * Synchronously checks a step's current values against that step's schema.
   * It defaults to the current step. Steps without a schema are always valid.
   */
  const isStepValid = useCallback(
    (index?: number): boolean => {
      const step = steps[index ?? currentStepIndex];

      if (!step?.schema) {
        return true;
      }

      const stepSchema: z.ZodTypeAny = step.schema;
      // Only the slice of values stored under the step name is parsed.
      // Missing data is treated as an empty object.
      const stepValues: unknown =
        form.getValues(step.name as FormFieldPath) ?? {};

      return stepSchema.safeParse(stepValues).success;
    },
    [steps, currentStepIndex, form]
  );

  /** Full dotted paths of the fields declared by the step at `index`. */
  const getStepFields = useCallback(
    (index: number): string[] => {
      const step = steps[index];

      if (!step?.schema) {
        return [];
      }

      return getSchemaFieldNames(step.schema).map(
        (field) => `${step.name}.${field}`
      );
    },
    [steps]
  );

  /** Asks react-hook-form to validate, and surface errors for, the step's fields. */
  const validateStep = useCallback(
    (index: number) => {
      const fields = getStepFields(index);

      if (fields.length === 0) {
        return;
      }

      // A single trigger call covers all of the step's fields. The promise is
      // not awaited because navigation relies on the synchronous isStepValid.
      void form.trigger(fields as FormFieldPath[]);
    },
    [getStepFields, form]
  );

  /** Advances one step if the current one is valid (or validation is bypassed). */
  const nextStep = useCallback(
    <Ev extends SyntheticEvent>(e: Ev) => {
      e.preventDefault();

      validateStep(currentStepIndex);

      if (!allowInvalidNavigation && !isStepValid()) {
        return;
      }

      if (currentStepIndex < steps.length - 1) {
        setDirection("forward");
        setCurrentStepIndex((prev) => prev + 1);
      }
    },
    [validateStep, allowInvalidNavigation, isStepValid, currentStepIndex, steps]
  );

  /** Goes back one step. Leaving a step backwards does not validate it. */
  const prevStep = useCallback(
    <Ev extends SyntheticEvent>(e: Ev) => {
      e.preventDefault();

      if (currentStepIndex > 0) {
        setDirection("backward");
        setCurrentStepIndex((prev) => prev - 1);
      }
    },
    [currentStepIndex]
  );

  /**
   * Jumps to `index` if it is in range and the current step is valid (or
   * validation is bypassed). Jumping to the current step is a no-op.
   */
  const goToStep = useCallback(
    (index: number) => {
      if (currentStepIndex === index) {
        return;
      }

      validateStep(currentStepIndex);

      if (
        index >= 0 &&
        index < steps.length &&
        (isStepValid() || allowInvalidNavigation)
      ) {
        setDirection(index > currentStepIndex ? "forward" : "backward");
        setCurrentStepIndex(index);
      }
    },
    [
      isStepValid,
      steps.length,
      currentStepIndex,
      validateStep,
      allowInvalidNavigation,
    ]
  );

  /** Top-level error keys that correspond to a known step name. */
  const invalidSteps = useMemo(() => {
    const formErrors = form.formState.errors;

    return Object.keys(formErrors).filter((key) => {
      const stepName = key.split(".")[0];
      return steps.some((step) => step.name === stepName);
    });
  }, [form.formState, steps]);

  return useMemo(
    () => ({
      form,
      currentStep: steps[currentStepIndex],
      currentStepIndex,
      totalSteps: steps.length,
      isFirstStep: currentStepIndex === 0,
      isLastStep: currentStepIndex === steps.length - 1,
      nextStep,
      prevStep,
      goToStep,
      direction,
      isStepValid,
      invalidSteps,
    }),
    [
      form,
      steps,
      currentStepIndex,
      nextStep,
      prevStep,
      goToStep,
      direction,
      isStepValid,
      invalidSteps,
    ]
  );
}
