From 07bc40ce75ac21d4ccffbd70ce2bef3a4ce784eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Ku=CC=88hn?= Date: Thu, 15 Apr 2021 21:14:33 +0200 Subject: [PATCH] wip: add getExtensionField --- .../Examples/CodeBlockLanguage/index.vue | 43 +---- packages/core/src/ExtensionManager.ts | 172 ++++++++++-------- .../src/helpers/createExtensionContext.ts | 101 ---------- .../helpers/getAttributesFromExtensions.ts | 36 +++- .../core/src/helpers/getExtensionField.ts | 25 +++ packages/core/src/helpers/getSchema.ts | 68 +++---- packages/core/src/index.ts | 2 +- packages/core/src/types.ts | 18 +- packages/core/src/utilities/callOrReturn.ts | 6 +- packages/extension-gapcursor/src/gapcursor.ts | 8 +- packages/extension-table/src/table.ts | 8 +- 11 files changed, 213 insertions(+), 274 deletions(-) delete mode 100644 packages/core/src/helpers/createExtensionContext.ts create mode 100644 packages/core/src/helpers/getExtensionField.ts diff --git a/docs/src/demos/Examples/CodeBlockLanguage/index.vue b/docs/src/demos/Examples/CodeBlockLanguage/index.vue index 6fe60eb8..b713b806 100644 --- a/docs/src/demos/Examples/CodeBlockLanguage/index.vue +++ b/docs/src/demos/Examples/CodeBlockLanguage/index.vue @@ -41,48 +41,13 @@ export default { Document, Paragraph, Text, - - Extension - .create({ - defaultOptions: { - foo: 'foo0', - }, - addProseMirrorPlugins() { - console.log(0, this.options) - return [] - }, - }) + CodeBlockLowlight .extend({ - // defaultOptions: { - // foo: 'foo1', - // }, - addProseMirrorPlugins() { - console.log(1, this.options) - // console.log(1, this.parentConfig.addProseMirrorPlugins) - return this.parentConfig.addProseMirrorPlugins?.() || [] - // return [] + addNodeView() { + return VueNodeViewRenderer(CodeBlockComponent) }, }) - .extend({ - // defaultOptions: { - // foo: 'foo2', - // }, - }) - // .extend({ - // addProseMirrorPlugins() { - // console.log(2, this.parentConfig.addProseMirrorPlugins) - // // return this.parentConfig.addProseMirrorPlugins?.() || [] - // return [] - // }, - // }) - - // CodeBlockLowlight - // .extend({ - // addNodeView() { - // return VueNodeViewRenderer(CodeBlockComponent) - // }, - // }) - // // .configure({ lowlight }), + .configure({ lowlight }), ], content: `

diff --git a/packages/core/src/ExtensionManager.ts b/packages/core/src/ExtensionManager.ts index 8f9df005..0b8cef7d 100644 --- a/packages/core/src/ExtensionManager.ts +++ b/packages/core/src/ExtensionManager.ts @@ -4,8 +4,8 @@ import { inputRules as inputRulesPlugin } from 'prosemirror-inputrules' import { EditorView, Decoration } from 'prosemirror-view' import { Plugin } from 'prosemirror-state' import { Editor } from './Editor' -import { Extensions, NodeViewRenderer, RawCommands } from './types' -import createExtensionContext from './helpers/createExtensionContext' +import { Extensions, RawCommands, AnyConfig } from './types' +import getExtensionField from './helpers/getExtensionField' import getSchema from './helpers/getSchema' import getSchemaTypeByName from './helpers/getSchemaTypeByName' import getNodeType from './helpers/getNodeType' @@ -13,6 +13,7 @@ import splitExtensions from './helpers/splitExtensions' import getAttributesFromExtensions from './helpers/getAttributesFromExtensions' import getRenderedAttributes from './helpers/getRenderedAttributes' import callOrReturn from './utilities/callOrReturn' +import { NodeConfig } from '.' export default class ExtensionManager { @@ -30,51 +31,51 @@ export default class ExtensionManager { this.schema = getSchema(this.extensions) this.extensions.forEach(extension => { - const context = createExtensionContext(extension, { + const context = { options: extension.options, editor: this.editor, type: getSchemaTypeByName(extension.config.name, this.schema), - }) + } if (extension.type === 'mark') { - const keepOnSplit = callOrReturn(extension.config.keepOnSplit, context) ?? true + const keepOnSplit = callOrReturn(getExtensionField(extension, 'keepOnSplit', context)) ?? true if (keepOnSplit) { this.splittableMarks.push(extension.config.name) } } - if (typeof extension.config.onBeforeCreate === 'function') { - this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context)) - } + // if (typeof extension.config.onBeforeCreate === 'function') { + // this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context)) + // } - if (typeof extension.config.onCreate === 'function') { - this.editor.on('create', extension.config.onCreate.bind(context)) - } + // if (typeof extension.config.onCreate === 'function') { + // this.editor.on('create', extension.config.onCreate.bind(context)) + // } - if (typeof extension.config.onUpdate === 'function') { - this.editor.on('update', extension.config.onUpdate.bind(context)) - } + // if (typeof extension.config.onUpdate === 'function') { + // this.editor.on('update', extension.config.onUpdate.bind(context)) + // } - if (typeof extension.config.onSelectionUpdate === 'function') { - this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context)) - } + // if (typeof extension.config.onSelectionUpdate === 'function') { + // this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context)) + // } - if (typeof extension.config.onTransaction === 'function') { - this.editor.on('transaction', extension.config.onTransaction.bind(context)) - } + // if (typeof extension.config.onTransaction === 'function') { + // this.editor.on('transaction', extension.config.onTransaction.bind(context)) + // } - if (typeof extension.config.onFocus === 'function') { - this.editor.on('focus', extension.config.onFocus.bind(context)) - } + // if (typeof extension.config.onFocus === 'function') { + // this.editor.on('focus', extension.config.onFocus.bind(context)) + // } - if (typeof extension.config.onBlur === 'function') { - this.editor.on('blur', extension.config.onBlur.bind(context)) - } + // if (typeof extension.config.onBlur === 'function') { + // this.editor.on('blur', extension.config.onBlur.bind(context)) + // } - if (typeof extension.config.onDestroy === 'function') { - this.editor.on('destroy', extension.config.onDestroy.bind(context)) - } + // if (typeof extension.config.onDestroy === 'function') { + // this.editor.on('destroy', extension.config.onDestroy.bind(context)) + // } }) } @@ -96,11 +97,11 @@ export default class ExtensionManager { get commands(): RawCommands { return this.extensions.reduce((commands, extension) => { - const context = createExtensionContext(extension, { + const context = { options: extension.options, editor: this.editor, type: getSchemaTypeByName(extension.config.name, this.schema), - }) + } if (!extension.config.addCommands) { return commands @@ -108,7 +109,7 @@ export default class ExtensionManager { return { ...commands, - ...extension.config.addCommands.bind(context)(), + ...getExtensionField(extension, 'addCommands', context)(), } }, {} as RawCommands) } @@ -117,22 +118,34 @@ export default class ExtensionManager { return [...this.extensions] .reverse() .map(extension => { - const context = createExtensionContext(extension, { + const context = { options: extension.options, editor: this.editor, type: getSchemaTypeByName(extension.config.name, this.schema), - }) + } const plugins: Plugin[] = [] - if (extension.config.addKeyboardShortcuts) { - const keyMapPlugin = keymap(extension.config.addKeyboardShortcuts.bind(context)()) + const addKeyboardShortcuts = getExtensionField( + extension, + 'addKeyboardShortcuts', + context, + ) + + if (addKeyboardShortcuts) { + const keyMapPlugin = keymap(addKeyboardShortcuts()) plugins.push(keyMapPlugin) } - if (this.editor.options.enableInputRules && extension.config.addInputRules) { - const inputRules = extension.config.addInputRules.bind(context)() + const addInputRules = getExtensionField( + extension, + 'addInputRules', + context, + ) + + if (this.editor.options.enableInputRules && addInputRules) { + const inputRules = addInputRules() const inputRulePlugins = inputRules.length ? [inputRulesPlugin({ rules: inputRules })] : [] @@ -140,46 +153,30 @@ export default class ExtensionManager { plugins.push(...inputRulePlugins) } - if (this.editor.options.enablePasteRules && extension.config.addPasteRules) { - const pasteRulePlugins = extension.config.addPasteRules.bind(context)() + const addPasteRules = getExtensionField( + extension, + 'addPasteRules', + context, + ) + + if (this.editor.options.enablePasteRules && addPasteRules) { + const pasteRulePlugins = addPasteRules() plugins.push(...pasteRulePlugins) } - // console.log('has pm', extension.config.addProseMirrorPlugins, extension) + const addProseMirrorPlugins = getExtensionField( + extension, + 'addProseMirrorPlugins', + context, + ) - const getItem = (rootext: any, ext: any, field: string): any => { - const realctx = createExtensionContext(ext, { - options: rootext.options, - // options: getItem(ext, 'defaultOptions'), - editor: this.editor, - type: getSchemaTypeByName(ext.config.name, this.schema), - }) + if (addProseMirrorPlugins) { + const proseMirrorPlugins = addProseMirrorPlugins() - if (ext.config[field]) { - if (typeof ext.config[field] === 'function') { - return ext.config[field].bind(realctx)() - } - - return ext.config[field] - } - - if (ext.parent) { - return getItem(rootext, ext.parent, field) - } - - return undefined + plugins.push(...proseMirrorPlugins) } - // console.log('get PM', getItem(extension, 'addProseMirrorPlugins', context)) - const realPMP = getItem(extension, extension, 'addProseMirrorPlugins') - - // if (extension.config.addProseMirrorPlugins) { - // const proseMirrorPlugins = extension.config.addProseMirrorPlugins.bind(context)() - - // plugins.push(...proseMirrorPlugins) - // } - return plugins }) .flat() @@ -194,15 +191,24 @@ export default class ExtensionManager { const { nodeExtensions } = splitExtensions(this.extensions) return Object.fromEntries(nodeExtensions - .filter(extension => !!extension.config.addNodeView) + .filter(extension => !!getExtensionField(extension, 'addNodeView')) .map(extension => { - const extensionAttributes = this.attributes.filter(attribute => attribute.type === extension.config.name) - const context = createExtensionContext(extension, { + const name = getExtensionField(extension, 'name') + const extensionAttributes = this.attributes.filter(attribute => attribute.type === name) + const context = { options: extension.options, editor, type: getNodeType(extension.config.name, this.schema), - }) - const renderer = extension.config.addNodeView?.call(context) as NodeViewRenderer + } + const addNodeView = getExtensionField( + extension, + 'addNodeView', + context, + ) + + if (!addNodeView) { + return [] + } const nodeview = ( node: ProsemirrorNode, @@ -212,7 +218,7 @@ export default class ExtensionManager { ) => { const HTMLAttributes = getRenderedAttributes(node, extensionAttributes) - return renderer({ + return addNodeView()({ editor, node, getPos, @@ -231,15 +237,21 @@ export default class ExtensionManager { const { nodeExtensions } = splitExtensions(this.extensions) return Object.fromEntries(nodeExtensions - .filter(extension => !!extension.config.renderText) + .filter(extension => !!getExtensionField(extension, 'renderText')) .map(extension => { - const context = createExtensionContext(extension, { + const context = { options: extension.options, editor, type: getNodeType(extension.config.name, this.schema), - }) + } - const textSerializer = (props: { node: ProsemirrorNode }) => extension.config.renderText?.call(context, props) + const renderText = getExtensionField(extension, 'renderText', context) + + if (!renderText) { + return [] + } + + const textSerializer = (props: { node: ProsemirrorNode }) => renderText(props) return [extension.config.name, textSerializer] })) diff --git a/packages/core/src/helpers/createExtensionContext.ts b/packages/core/src/helpers/createExtensionContext.ts deleted file mode 100644 index 675c0d52..00000000 --- a/packages/core/src/helpers/createExtensionContext.ts +++ /dev/null @@ -1,101 +0,0 @@ -import { AnyExtension, AnyObject } from '../types' - -// export default function createExtensionContext( -// extension: AnyExtension, -// data: T, -// ): T & { parentConfig: AnyObject } { -// const context: any = { -// ...data, -// // get parentConfig() { -// // return Object.fromEntries(Object.entries(extension.parentConfig).map(([key, value]) => { -// // if (typeof value !== 'function') { -// // return [key, value] -// // } - -// // console.log('call', key) - -// // return [key, value.bind(context)] -// // })) -// // }, - -// parentConfig: Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => { -// if (typeof value !== 'function') { -// return [key, value] -// } - -// // console.log('call', key) - -// return [key, value.bind(data)] -// })), - -// // get parentConfig() { -// // console.log('parent', extension.parent) -// // console.log('parent parent', extension.parent?.parent) - -// // return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => { -// // if (typeof value !== 'function') { -// // return [key, value] -// // } - -// // // console.log('call', key) - -// // return [key, value.bind(context)] -// // })) -// // }, - -// // parentConfig: null, -// } - -// return context -// } - -// export default function createExtensionContext( -// extension: AnyExtension, -// data: T, -// // @ts-ignore -// ): T & { parentConfig: AnyObject } { -// const context: any = data - -// if (!extension.parent) { -// context.parentConfig = {} - -// return context -// } - -// // const bla = { -// // ...(extension.parent.parent ? extension.parent.parent.config : {}), -// // ...extension.parent.config, -// // } - -// context.parentConfig = Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => { -// if (typeof value !== 'function') { -// return [key, value] -// } - -// // console.log('call', key) - -// return [key, value.bind(createExtensionContext(extension.parent, data))] -// })) - -// return context -// } - -export default function createExtensionContext( - extension: AnyExtension, - data: T, -): T & { parentConfig: AnyObject } { - const context: any = { - ...data, - get parentConfig() { - return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => { - if (typeof value !== 'function') { - return [key, value] - } - - return [key, value.bind(context)] - })) - }, - } - - return context -} diff --git a/packages/core/src/helpers/getAttributesFromExtensions.ts b/packages/core/src/helpers/getAttributesFromExtensions.ts index 467a6683..c70923e3 100644 --- a/packages/core/src/helpers/getAttributesFromExtensions.ts +++ b/packages/core/src/helpers/getAttributesFromExtensions.ts @@ -1,12 +1,14 @@ -import createExtensionContext from './createExtensionContext' import splitExtensions from './splitExtensions' +import getExtensionField from './getExtensionField' import { Extensions, GlobalAttributes, Attributes, Attribute, ExtensionAttribute, + AnyConfig, } from '../types' +import { NodeConfig, MarkConfig } from '..' /** * Get a list of all extension attributes defined in `addAttribute` and `addGlobalAttribute`. @@ -25,15 +27,22 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext } extensions.forEach(extension => { - const context = createExtensionContext(extension, { + const context = { options: extension.options, - }) + } - if (!extension.config.addGlobalAttributes) { + const addGlobalAttributes = getExtensionField( + extension, + 'addGlobalAttributes', + context, + ) + + if (!addGlobalAttributes) { return } - const globalAttributes = extension.config.addGlobalAttributes.bind(context)() as GlobalAttributes + // TODO: remove `as GlobalAttributes` + const globalAttributes = addGlobalAttributes() as GlobalAttributes globalAttributes.forEach(globalAttribute => { globalAttribute.types.forEach(type => { @@ -54,21 +63,28 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext }) nodeAndMarkExtensions.forEach(extension => { - const context = createExtensionContext(extension, { + const context = { options: extension.options, - }) + } - if (!extension.config.addAttributes) { + const addAttributes = getExtensionField( + extension, + 'addAttributes', + context, + ) + + if (!addAttributes) { return } - const attributes = extension.config.addAttributes.bind(context)() as Attributes + // TODO: remove `as Attributes` + const attributes = addAttributes() as Attributes Object .entries(attributes) .forEach(([name, attribute]) => { extensionAttributes.push({ - type: extension.config.name, + type: getExtensionField(extension, 'name'), name, attribute: { ...defaultAttribute, diff --git a/packages/core/src/helpers/getExtensionField.ts b/packages/core/src/helpers/getExtensionField.ts new file mode 100644 index 00000000..53065884 --- /dev/null +++ b/packages/core/src/helpers/getExtensionField.ts @@ -0,0 +1,25 @@ +import { AnyExtension, AnyObject, RemoveThis } from '../types' + +export default function getExtensionField( + extension: AnyExtension, + field: string, + context: AnyObject = {}, +): RemoveThis { + + if (extension.config[field] === undefined && extension.parent) { + return getExtensionField(extension.parent, field, context) + } + + if (typeof extension.config[field] === 'function') { + const value = extension.config[field].bind({ + ...context, + parent: extension.parent + ? getExtensionField(extension.parent, field, context) + : null, + }) + + return value + } + + return extension.config[field] +} diff --git a/packages/core/src/helpers/getSchema.ts b/packages/core/src/helpers/getSchema.ts index 4ad1ddde..05cc0074 100644 --- a/packages/core/src/helpers/getSchema.ts +++ b/packages/core/src/helpers/getSchema.ts @@ -1,13 +1,13 @@ import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model' import { Extensions } from '../types' import { ExtensionConfig, NodeConfig, MarkConfig } from '..' -import createExtensionContext from './createExtensionContext' import splitExtensions from './splitExtensions' import getAttributesFromExtensions from './getAttributesFromExtensions' import getRenderedAttributes from './getRenderedAttributes' import isEmptyObject from '../utilities/isEmptyObject' import injectExtensionAttributesToParseRule from './injectExtensionAttributesToParseRule' import callOrReturn from '../utilities/callOrReturn' +import getExtensionField from './getExtensionField' function cleanUpSchemaItem(data: T) { return Object.fromEntries(Object.entries(data).filter(([key, value]) => { @@ -46,9 +46,9 @@ export default function getSchema(extensions: Extensions): Schema { const nodes = Object.fromEntries(nodeExtensions.map(extension => { const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) - const context = createExtensionContext(extension, { + const context = { options: extension.options, - }) + } const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => { const extraFields = callOrReturn(nodeSchemaExtender, context, extension) @@ -61,29 +61,32 @@ export default function getSchema(extensions: Extensions): Schema { const schema: NodeSpec = cleanUpSchemaItem({ ...extraNodeFields, - content: callOrReturn(extension.config.content, context), - marks: callOrReturn(extension.config.marks, context), - group: callOrReturn(extension.config.group, context), - inline: callOrReturn(extension.config.inline, context), - atom: callOrReturn(extension.config.atom, context), - selectable: callOrReturn(extension.config.selectable, context), - draggable: callOrReturn(extension.config.draggable, context), - code: callOrReturn(extension.config.code, context), - defining: callOrReturn(extension.config.defining, context), - isolating: callOrReturn(extension.config.isolating, context), + content: callOrReturn(getExtensionField(extension, 'content', context)), + marks: callOrReturn(getExtensionField(extension, 'marks', context)), + group: callOrReturn(getExtensionField(extension, 'group', context)), + inline: callOrReturn(getExtensionField(extension, 'inline', context)), + atom: callOrReturn(getExtensionField(extension, 'atom', context)), + selectable: callOrReturn(getExtensionField(extension, 'selectable', context)), + draggable: callOrReturn(getExtensionField(extension, 'draggable', context)), + code: callOrReturn(getExtensionField(extension, 'code', context)), + defining: callOrReturn(getExtensionField(extension, 'defining', context)), + isolating: callOrReturn(getExtensionField(extension, 'isolating', context)), attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] })), }) - if (extension.config.parseHTML) { - schema.parseDOM = extension.config.parseHTML - .bind(context)() - ?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes)) + const parseHTML = callOrReturn(getExtensionField(extension, 'parseHTML', context)) + + if (parseHTML) { + schema.parseDOM = parseHTML + .map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes)) } - if (extension.config.renderHTML) { - schema.toDOM = node => (extension.config.renderHTML as Function)?.bind(context)({ + const renderHTML = getExtensionField(extension, 'renderHTML', context) + + if (renderHTML) { + schema.toDOM = node => renderHTML({ node, HTMLAttributes: getRenderedAttributes(node, extensionAttributes), }) @@ -94,9 +97,9 @@ export default function getSchema(extensions: Extensions): Schema { const marks = Object.fromEntries(markExtensions.map(extension => { const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) - const context = createExtensionContext(extension, { + const context = { options: extension.options, - }) + } const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => { const extraFields = callOrReturn(markSchemaExtender, context, extension) @@ -109,23 +112,26 @@ export default function getSchema(extensions: Extensions): Schema { const schema: MarkSpec = cleanUpSchemaItem({ ...extraMarkFields, - inclusive: callOrReturn(extension.config.inclusive, context), - excludes: callOrReturn(extension.config.excludes, context), - group: callOrReturn(extension.config.group, context), - spanning: callOrReturn(extension.config.spanning, context), + inclusive: callOrReturn(getExtensionField(extension, 'inclusive', context)), + excludes: callOrReturn(getExtensionField(extension, 'excludes', context)), + group: callOrReturn(getExtensionField(extension, 'group', context)), + spanning: callOrReturn(getExtensionField(extension, 'spanning', context)), attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] })), }) - if (extension.config.parseHTML) { - schema.parseDOM = extension.config.parseHTML - .bind(context)() - ?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes)) + const parseHTML = callOrReturn(getExtensionField(extension, 'parseHTML', context)) + + if (parseHTML) { + schema.parseDOM = parseHTML + .map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes)) } - if (extension.config.renderHTML) { - schema.toDOM = mark => (extension.config.renderHTML as Function)?.bind(context)({ + const renderHTML = getExtensionField(extension, 'renderHTML', context) + + if (renderHTML) { + schema.toDOM = mark => renderHTML({ mark, HTMLAttributes: getRenderedAttributes(mark, extensionAttributes), }) diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index bb4309b6..f5492104 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -12,7 +12,7 @@ export { default as markPasteRule } from './pasteRules/markPasteRule' export { default as callOrReturn } from './utilities/callOrReturn' export { default as mergeAttributes } from './utilities/mergeAttributes' -export { default as createExtensionContext } from './helpers/createExtensionContext' +export { default as getExtensionField } from './helpers/getExtensionField' export { default as findChildren } from './helpers/findChildren' export { default as findParentNode } from './helpers/findParentNode' export { default as findParentNodeClosestToPos } from './helpers/findParentNodeClosestToPos' diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index db115d30..210e6434 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -14,17 +14,31 @@ import { Extension } from './Extension' import { Node } from './Node' import { Mark } from './Mark' import { Editor } from './Editor' -import { Commands } from '.' +import { + Commands, + ExtensionConfig, + NodeConfig, + MarkConfig, +} from '.' +export type AnyConfig = ExtensionConfig | NodeConfig | MarkConfig export type AnyExtension = Extension | Node | Mark export type Extensions = AnyExtension[] export type ParentConfig = Partial<{ - [P in keyof T]: Required[P] extends () => any + [P in keyof T]: Required[P] extends (...args: any) => any ? (...args: Parameters[P]>) => ReturnType[P]> : T[P] }> +export type RemoveThis = T extends (...args: any) => any + ? (...args: Parameters) => ReturnType + : T + +export type MaybeReturnType = T extends (...args: any) => any + ? ReturnType + : T + export interface EditorOptions { element: Element, content: Content, diff --git a/packages/core/src/utilities/callOrReturn.ts b/packages/core/src/utilities/callOrReturn.ts index abbefd9c..a87ee60f 100644 --- a/packages/core/src/utilities/callOrReturn.ts +++ b/packages/core/src/utilities/callOrReturn.ts @@ -1,3 +1,5 @@ +import { MaybeReturnType } from '../types' + /** * Optionally calls `value` as a function. * Otherwise it is returned directly. @@ -5,7 +7,7 @@ * @param context Optional context to bind to function. * @param props Optional props to pass to function. */ -export default function callOrReturn(value: any, context: any = undefined, ...props: any[]): any { +export default function callOrReturn(value: T, context: any = undefined, ...props: any[]): MaybeReturnType { if (typeof value === 'function') { if (context) { return value.bind(context)(...props) @@ -14,5 +16,5 @@ export default function callOrReturn(value: any, context: any = undefined, ...pr return value(...props) } - return value + return value as MaybeReturnType } diff --git a/packages/extension-gapcursor/src/gapcursor.ts b/packages/extension-gapcursor/src/gapcursor.ts index 6a32b73a..70ad0048 100644 --- a/packages/extension-gapcursor/src/gapcursor.ts +++ b/packages/extension-gapcursor/src/gapcursor.ts @@ -1,7 +1,7 @@ import { Extension, callOrReturn, - createExtensionContext, + getExtensionField, ParentConfig, } from '@tiptap/core' import { gapCursor } from 'prosemirror-gapcursor' @@ -31,12 +31,12 @@ export const Gapcursor = Extension.create({ }, extendNodeSchema(extension) { - const context = createExtensionContext(extension, { + const context = { options: extension.options, - }) + } return { - allowGapCursor: callOrReturn(extension.config.allowGapCursor, context) ?? null, + allowGapCursor: callOrReturn(getExtensionField(extension, 'allowGapCursor', context)) ?? null, } }, }) diff --git a/packages/extension-table/src/table.ts b/packages/extension-table/src/table.ts index d3020359..6e3a12cd 100644 --- a/packages/extension-table/src/table.ts +++ b/packages/extension-table/src/table.ts @@ -3,9 +3,9 @@ import { Command, ParentConfig, mergeAttributes, + getExtensionField, findParentNodeClosestToPos, callOrReturn, - createExtensionContext, } from '@tiptap/core' import { tableEditing, @@ -264,12 +264,12 @@ export const Table = Node.create({ }, extendNodeSchema(extension) { - const context = createExtensionContext(extension, { + const context = { options: extension.options, - }) + } return { - tableRole: callOrReturn(extension.config.tableRole, context), + tableRole: callOrReturn(getExtensionField(extension, 'tableRole', context)), } }, })