/* NOTE(review): this copy of the file looks damaged. Text that sat between angle brackets seems to have been stripped (RefObject and useRef have lost their type arguments, and the JSX returned by BendableArrow, computeControlPoints and ArrowPoint is missing), and the file has been collapsed onto a few lines so that each line comment now swallows the code after it on the same line. Restore it from version control before relying on it; the comments below describe the code as it was presumably meant to read. */ import { CSSProperties, ReactElement, RefObject, useCallback, useEffect, useRef, useState, } from "react" import { add, angle, between, distance, middlePos, minus, mul, Pos, posWithinBase, ratioWithinBase, } from "./Pos" import "../../style/bendable_arrows.css" import Draggable from "react-draggable" /** Props of BendableArrow. */ export interface BendableArrowProps { /** Element whose bounding rect is the coordinate base: positions are converted against it with posWithinBase and ratioWithinBase (presumably they are stored as ratios of that rect; confirm in ./Pos). */ area: RefObject /** Point the arrow starts from, where the tail is drawn. */ startPos: Pos /** Ordered segments of the arrow; must not be empty (update throws otherwise). */ segments: Segment[] /** Receives a new segments array when a handle position is validated, a handle is removed, or a point is inserted by double-clicking the path. */ onSegmentsChanges: (edges: Segment[]) => void /** When true, control points are ignored and the path is drawn as one straight line from startPos to the end of the last segment; editing handles and double-click insertion are disabled. */ forceStraight: boolean /** Distance by which the tail is pulled back from startPos (see constraintInCircle); defaults to 0. */ startRadius?: number /** Distance by which the head is pulled back from the last point (see constraintInCircle); defaults to 0. */ endRadius?: number /** Called by a keyboard handler in the rendered markup when the key is Delete. */ onDeleteRequested?: () => void /** Optional visual settings. */ style?: ArrowStyle } /** Visual settings of a BendableArrow. */ export interface ArrowStyle { /** Line width; falls back to ArrowStyleDefaults.width. How it is applied is not visible in this copy. */ width?: number /** Presumably an SVG stroke-dasharray value; its use is not visible in this copy, confirm. */ dashArray?: string /** Renders the head decoration; invoked with the style object as this. */ head?: () => ReactElement /** Renders the tail decoration; invoked with the style object as this. */ tail?: () => ReactElement } /** Fallback values for optional ArrowStyle fields. */ const ArrowStyleDefaults = { width: 4, } /** One piece of the arrow. It starts at the end of the previous segment (at startPos for the first one) and ends at next. When controlPoint is unset, between() of the two ends (presumably their midpoint) is used as the control point. */ export interface Segment { next: Pos controlPoint?: Pos } /** Returns the point at distance radius from pos, in the direction derived from angle(pos, from) (sine for x, cosine for y). Used to pull the arrow ends back by startRadius and endRadius; whether that moves toward or away from from depends on the convention of angle() in ./Pos. */ function constraintInCircle(pos: Pos, from: Pos, radius: number): Pos { const theta = angle(pos, from) return { x: pos.x - Math.sin(theta) * radius, y: pos.y - Math.cos(theta) * radius, } } /** Renders an arrow from startPos through the given segments that the user can bend. Geometry is written directly to the DOM by update (path d attribute, svg position, head and tail placement) rather than through React state. While the arrow is selected and forceStraight is false, draggable handles are shown: dragging only updates the local internalSegments copy, while validating or removing a handle calls onSegmentsChanges. Double-clicking the path inserts a point. */ export default function BendableArrow({ area, startPos, segments, onSegmentsChanges, forceStraight, style, startRadius = 0, endRadius = 0, onDeleteRequested, }: BendableArrowProps) { const containerRef = useRef(null) const svgRef = useRef(null) const pathRef = useRef(null) const styleWidth = style?.width ?? ArrowStyleDefaults.width /* Re-sync the live copy whenever the parent passes new segments. setInternalSegments is declared just below, which is fine because effects only run after render. */ useEffect(() => { setInternalSegments(segments) }, [segments]) /* Live copy of segments: updated while handles are dragged, and drawn by update. */ const [internalSegments, setInternalSegments] = useState(segments) /* True when the latest document mousedown landed inside containerRef. */ const [isSelected, setIsSelected] = useState(false) /* Head and tail elements, positioned and rotated by update. */ const headRef = useRef(null) const tailRef = useRef(null) /** Builds the editing handles shown while the arrow is selected. For every segment it returns a handle for the control point (between() of the segment ends when unset) and, for every segment except the last, a handle for its next point. Dragging a handle updates internalSegments; validating or removing it calls onSegmentsChanges. @param parentBase bounding rect of area, used to convert between ratios and pixels. NOTE(review): it reads the segments prop rather than internalSegments. Removing the next-point handle of segment i splices index Math.max(i - 1, 0), which for i greater than 0 drops segment i - 1 (the previous point) instead of segment i; looks like an off-by-one, confirm. */ function computeControlPoints(parentBase: DOMRect) { return segments.flatMap(({ next, controlPoint }, i) => { const prev = i == 0 ? 
startPos : segments[i - 1].next /* Segment ends converted with posWithinBase against parentBase. */ const prevRelative = posWithinBase(prev, parentBase) const nextRelative = posWithinBase(next, parentBase) /* Stored control point, or the default one (between() of the two ends) converted back with ratioWithinBase. */ const cpPos = controlPoint || ratioWithinBase( add(between(prevRelative, nextRelative), parentBase), parentBase, ) /* Commits a new control point for segment i; undefined restores the default. */ const setControlPointPos = (newPos: Pos | undefined) => { const segment = segments[i] const newSegments = segments.toSpliced(i, 1, { ...segment, controlPoint: newPos, }) onSegmentsChanges(newSegments) } return [ // curve control point setControlPointPos(undefined)} onMoves={(controlPoint) => { setInternalSegments((is) => { return is.toSpliced(i, 1, { ...is[i], controlPoint, }) }) }} />, //next pos point (only if this is not the last segment) i != segments.length - 1 && ( { const currentSegment = segments[i] const newSegments = segments.toSpliced(i, 1, { ...currentSegment, next, }) onSegmentsChanges(newSegments) }} onRemove={() => { onSegmentsChanges( segments.toSpliced(Math.max(i - 1, 0), 1), ) }} onMoves={(next) => { setInternalSegments((is) => { return is.toSpliced(i, 1, { ...is[i], next, }) }) }} /> ), ] }) } const update = useCallback(() => { const parentBase = area.current!.getBoundingClientRect() const firstSegment = internalSegments[0] ?? null if (firstSegment == null) throw new Error("segments might not be empty.") const lastSegment = internalSegments[internalSegments.length - 1] const startRelative = posWithinBase(startPos, parentBase) const endRelative = posWithinBase(lastSegment.next, parentBase) const startNext = firstSegment.controlPoint && !forceStraight ? posWithinBase(firstSegment.controlPoint, parentBase) : posWithinBase(firstSegment.next, parentBase) const endPrevious = forceStraight ? startRelative : lastSegment.controlPoint ? posWithinBase(lastSegment.controlPoint, parentBase) : internalSegments[internalSegments.length - 2] ? 
posWithinBase( internalSegments[internalSegments.length - 2].next, parentBase, ) : startRelative /* Pull the tail and head back from their anchor points by startRadius and endRadius. */ const tailPos = constraintInCircle( startRelative, startNext, startRadius!, ) const headPos = constraintInCircle(endRelative, endPrevious, endRadius!) /* The svg is placed at the top-left corner of the box spanned by tail and head; path coordinates are made relative to that corner. */ const left = Math.min(tailPos.x, headPos.x) const top = Math.min(tailPos.y, headPos.y) /* Place the tail and head elements and rotate them according to the angle toward their neighbouring point. */ Object.assign(tailRef.current!.style, { left: tailPos.x + "px", top: tailPos.y + "px", transformOrigin: "top center", transform: `translateX(-50%) rotate(${ -angle(tailPos, startNext) * (180 / Math.PI) }deg)`, } as CSSProperties) Object.assign(headRef.current!.style, { left: headPos.x + "px", top: headPos.y + "px", transformOrigin: "top center", transform: `translateX(-50%) rotate(${ -angle(headPos, endPrevious) * (180 / Math.PI) }deg)`, } as CSSProperties) const svgStyle: CSSProperties = { left: left + "px", top: top + "px", } /* End and control point of each drawn segment, via posWithinBase; when there is no usable control point, between() of the segment ends is used. With forceStraight only the last segment is kept, anchored at startRelative. */ const segmentsRelatives = ( forceStraight ? internalSegments.slice(-1) : internalSegments ).map(({ next, controlPoint }, idx) => { const nextPos = posWithinBase(next, parentBase) return { next: nextPos, cp: controlPoint && !forceStraight ? posWithinBase(controlPoint, parentBase) : between( idx == 0 ? startRelative : posWithinBase( internalSegments[idx - 1].next, parentBase, ), nextPos, ), } }) /* One SVG cubic command per segment, with both cubic control points set to the segment control point. */ const computedSegments = segmentsRelatives .map(({ next: n, cp }, idx) => { let next = n /* NOTE(review): idx is compared with internalSegments.length even when segmentsRelatives holds only the last segment (forceStraight), so with several segments the end of the straight path is not pulled back by endRadius while the head element is; confirm. */ if (idx == internalSegments.length - 1) { //if it is the last element next = constraintInCircle(next, cp, endRadius!) 
} return `C${cp.x - left} ${cp.y - top}, ${cp.x - left} ${ cp.y - top }, ${next.x - left} ${next.y - top}` }) .join(" ") /* Full path: move to the tail, then the segment curves. */ const d = `M${tailPos.x - left} ${tailPos.y - top} ` + computedSegments pathRef.current!.setAttribute("d", d) Object.assign(svgRef.current!.style, svgStyle) }, [startPos, internalSegments, forceStraight]) /* update (above) redraws the arrow from internalSegments and throws when there are none. NOTE(review): its dependency list omits area, startRadius and endRadius, so changing only those does not trigger a redraw. With forceStraight and several segments, the tail is rotated toward the first segment end (startNext) while the straight path heads for the last one; confirm. */ useEffect(update, [update]) /* Selection: a document mousedown whose target is a Node selects the arrow when that target is inside containerRef and deselects it otherwise. The window resize listener redraws the arrow. */ useEffect(() => { const selectionHandler = (e: MouseEvent) => { if (!(e.target instanceof Node)) return const isSelected = containerRef.current!.contains(e.target) setIsSelected(isSelected) } document.addEventListener("mousedown", selectionHandler) window.addEventListener("resize", update) return () => { document.removeEventListener("mousedown", selectionHandler) window.removeEventListener("resize", update) } }, [update, containerRef]) /* Double-clicking the path inserts a point at the click position (disabled with forceStraight). The first segment whose quadratic curve passes within 0.05 of the click, in the units of clickPosBaseRatio, is split in two at the click (see searchOnSegment). The original control point goes to the first new segment when the click is at t of 0.5 or more and to the second one otherwise; the other half gets none. NOTE(review): the hit test uses a quadratic curve while update draws a cubic with both control points equal, so the two curves differ slightly; the dependency list also omits forceStraight, startPos and area. */ useEffect(() => { if (forceStraight) return const addSegment = (e: MouseEvent) => { const parentBase = area.current!.getBoundingClientRect() const clickAbsolutePos: Pos = { x: e.x, y: e.y } const clickPosBaseRatio = ratioWithinBase( clickAbsolutePos, parentBase, ) let segmentInsertionIndex = -1 let segmentInsertionIsOnRightOfCP = false for (let i = 0; i < segments.length; i++) { const segment = segments[i] let currentPos = i == 0 ? startPos : segments[i - 1].next let nextPos = segment.next let controlPointPos = segment.controlPoint ? segment.controlPoint : between(currentPos, nextPos) const result = searchOnSegment( currentPos, controlPointPos, nextPos, clickPosBaseRatio, 0.05, ) if (result == PointSegmentSearchResult.NOT_FOUND) continue segmentInsertionIndex = i segmentInsertionIsOnRightOfCP = result == PointSegmentSearchResult.RIGHT_TO_CONTROL_POINT break } if (segmentInsertionIndex == -1) return const splicedSegment: Segment = segments[segmentInsertionIndex] onSegmentsChanges( segments.toSpliced( segmentInsertionIndex, 1, { next: clickPosBaseRatio, controlPoint: segmentInsertionIsOnRightOfCP ? 
splicedSegment.controlPoint : undefined, }, { next: splicedSegment.next, controlPoint: segmentInsertionIsOnRightOfCP ? undefined : splicedSegment.controlPoint, }, ), ) } pathRef?.current?.addEventListener("dblclick", addSegment) return () => { pathRef?.current?.removeEventListener("dblclick", addSegment) } }, [pathRef, segments, onSegmentsChanges]) /* Rendered markup: presumably the container, svg path, head and tail elements referenced above, plus the editing handles while selected. Most of it is missing from this copy. */ return (
{ if (onDeleteRequested && e.key == "Delete") onDeleteRequested() }} />
{style?.head?.call(style)}
{style?.tail?.call(style)}
{!forceStraight && isSelected && computeControlPoints(area.current!.getBoundingClientRect())}
) } /** Props of ArrowPoint, the draggable handle (presumably the element computeControlPoints renders). posRatio is the handle position, converted with posWithinBase against parentBase. onMoves receives the handle position (middlePos of its bounding rect, converted with ratioWithinBase) on every drag step; onPosValidated receives it from another Draggable handler whose prop name is missing from this copy (presumably drag stop); onRemove is called when a keyboard handler on the handle sees the Delete key. radius (default 7) shifts the Draggable position by radius on both axes, presumably so that a handle of diameter 2 * radius is centred on the point. NOTE(review): the module-level loop placed after searchOnSegment (let t = 0, let slice = 0.5, 100 iterations) runs once at import time, only mutates module variables not used elsewhere in this file, and has its console.log commented out; it looks like leftover debug code and could be removed. */ interface ControlPointProps { className: string posRatio: Pos parentBase: DOMRect onMoves: (currentPos: Pos) => void onPosValidated: (newPos: Pos) => void onRemove: () => void radius?: number } /** Result of searchOnSegment: the point was found at t of 0.5 or more (RIGHT_TO_CONTROL_POINT), below 0.5 (LEFT_TO_CONTROL_POINT), or not at all (NOT_FOUND). */ enum PointSegmentSearchResult { LEFT_TO_CONTROL_POINT, RIGHT_TO_CONTROL_POINT, NOT_FOUND, } /** Samples the quadratic Bezier curve defined by startPos, controlPoint and endPos for t from 0 up to but excluding 1, and reports on which side of t = 0.5 the first sample within minDistance of point lies. @returns NOT_FOUND when no sample is close enough. */ function searchOnSegment( startPos: Pos, controlPoint: Pos, endPos: Pos, point: Pos, minDistance: number, ): PointSegmentSearchResult { /* t increment: minDistance divided by the length of the control polygon. It is Infinity when the three points coincide (only t = 0 is then sampled), and 0 when minDistance is 0 with a non-degenerate polygon, in which case the loop below never advances; minDistance must therefore be positive. */ const step = 1 / ((distance(startPos, controlPoint) + distance(controlPoint, endPos)) / minDistance) const p0MinusP1 = minus(startPos, controlPoint) const p2MinusP1 = minus(endPos, controlPoint) /* Distance from point to B(t) = P1 + (1 - t)^2 (P0 - P1) + t^2 (P2 - P1), the quadratic Bezier written relative to its control point P1 = controlPoint. */ function getDistanceAt(t: number): number { // apply the bezier function const pos = add( add(controlPoint, mul(p0MinusP1, (1 - t) ** 2)), mul(p2MinusP1, t ** 2), ) return distance(pos, point) } for (let t = 0; t < 1; t += step) { if (getDistanceAt(t) <= minDistance) return t >= 0.5 ? PointSegmentSearchResult.RIGHT_TO_CONTROL_POINT : PointSegmentSearchResult.LEFT_TO_CONTROL_POINT } return PointSegmentSearchResult.NOT_FOUND } let t = 0 let slice = 0.5 for (let i = 0; i < 100; i++) { t += slice slice /= 2 // console.log(t) } function ArrowPoint({ className, posRatio, parentBase, onMoves, onPosValidated, onRemove, radius = 7, }: ControlPointProps) { const ref = useRef(null) const pos = posWithinBase(posRatio, parentBase) return ( { const pointPos = middlePos(ref.current!.getBoundingClientRect()) onPosValidated(ratioWithinBase(pointPos, parentBase)) }} onDrag={() => { const pointPos = middlePos(ref.current!.getBoundingClientRect()) onMoves(ratioWithinBase(pointPos, parentBase)) }} position={{ x: pos.x - radius, y: pos.y - radius }}>
{ if (e.key == "Delete") { onRemove() } }} tabIndex={0} /> ) }