diff --git a/docs/src/demos/Examples/CodeBlockLanguage/index.vue b/docs/src/demos/Examples/CodeBlockLanguage/index.vue
index 6fe60eb8..b713b806 100644
--- a/docs/src/demos/Examples/CodeBlockLanguage/index.vue
+++ b/docs/src/demos/Examples/CodeBlockLanguage/index.vue
@@ -41,48 +41,13 @@ export default {
Document,
Paragraph,
Text,
-
- Extension
- .create({
- defaultOptions: {
- foo: 'foo0',
- },
- addProseMirrorPlugins() {
- console.log(0, this.options)
- return []
- },
- })
+ CodeBlockLowlight
.extend({
- // defaultOptions: {
- // foo: 'foo1',
- // },
- addProseMirrorPlugins() {
- console.log(1, this.options)
- // console.log(1, this.parentConfig.addProseMirrorPlugins)
- return this.parentConfig.addProseMirrorPlugins?.() || []
- // return []
+ addNodeView() {
+ return VueNodeViewRenderer(CodeBlockComponent)
},
})
- .extend({
- // defaultOptions: {
- // foo: 'foo2',
- // },
- })
- // .extend({
- // addProseMirrorPlugins() {
- // console.log(2, this.parentConfig.addProseMirrorPlugins)
- // // return this.parentConfig.addProseMirrorPlugins?.() || []
- // return []
- // },
- // })
-
- // CodeBlockLowlight
- // .extend({
- // addNodeView() {
- // return VueNodeViewRenderer(CodeBlockComponent)
- // },
- // })
- // // .configure({ lowlight }),
+ .configure({ lowlight }),
],
content: `
diff --git a/packages/core/src/ExtensionManager.ts b/packages/core/src/ExtensionManager.ts
index 8f9df005..0b8cef7d 100644
--- a/packages/core/src/ExtensionManager.ts
+++ b/packages/core/src/ExtensionManager.ts
@@ -4,8 +4,8 @@ import { inputRules as inputRulesPlugin } from 'prosemirror-inputrules'
import { EditorView, Decoration } from 'prosemirror-view'
import { Plugin } from 'prosemirror-state'
import { Editor } from './Editor'
-import { Extensions, NodeViewRenderer, RawCommands } from './types'
-import createExtensionContext from './helpers/createExtensionContext'
+import { Extensions, RawCommands, AnyConfig } from './types'
+import getExtensionField from './helpers/getExtensionField'
import getSchema from './helpers/getSchema'
import getSchemaTypeByName from './helpers/getSchemaTypeByName'
import getNodeType from './helpers/getNodeType'
@@ -13,6 +13,7 @@ import splitExtensions from './helpers/splitExtensions'
import getAttributesFromExtensions from './helpers/getAttributesFromExtensions'
import getRenderedAttributes from './helpers/getRenderedAttributes'
import callOrReturn from './utilities/callOrReturn'
+import { NodeConfig } from '.'
export default class ExtensionManager {
@@ -30,51 +31,51 @@ export default class ExtensionManager {
this.schema = getSchema(this.extensions)
this.extensions.forEach(extension => {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema),
- })
+ }
if (extension.type === 'mark') {
- const keepOnSplit = callOrReturn(extension.config.keepOnSplit, context) ?? true
+ const keepOnSplit = callOrReturn(getExtensionField(extension, 'keepOnSplit', context)) ?? true
if (keepOnSplit) {
this.splittableMarks.push(extension.config.name)
}
}
- if (typeof extension.config.onBeforeCreate === 'function') {
- this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context))
- }
+ // if (typeof extension.config.onBeforeCreate === 'function') {
+ // this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context))
+ // }
- if (typeof extension.config.onCreate === 'function') {
- this.editor.on('create', extension.config.onCreate.bind(context))
- }
+ // if (typeof extension.config.onCreate === 'function') {
+ // this.editor.on('create', extension.config.onCreate.bind(context))
+ // }
- if (typeof extension.config.onUpdate === 'function') {
- this.editor.on('update', extension.config.onUpdate.bind(context))
- }
+ // if (typeof extension.config.onUpdate === 'function') {
+ // this.editor.on('update', extension.config.onUpdate.bind(context))
+ // }
- if (typeof extension.config.onSelectionUpdate === 'function') {
- this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context))
- }
+ // if (typeof extension.config.onSelectionUpdate === 'function') {
+ // this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context))
+ // }
- if (typeof extension.config.onTransaction === 'function') {
- this.editor.on('transaction', extension.config.onTransaction.bind(context))
- }
+ // if (typeof extension.config.onTransaction === 'function') {
+ // this.editor.on('transaction', extension.config.onTransaction.bind(context))
+ // }
- if (typeof extension.config.onFocus === 'function') {
- this.editor.on('focus', extension.config.onFocus.bind(context))
- }
+ // if (typeof extension.config.onFocus === 'function') {
+ // this.editor.on('focus', extension.config.onFocus.bind(context))
+ // }
- if (typeof extension.config.onBlur === 'function') {
- this.editor.on('blur', extension.config.onBlur.bind(context))
- }
+ // if (typeof extension.config.onBlur === 'function') {
+ // this.editor.on('blur', extension.config.onBlur.bind(context))
+ // }
- if (typeof extension.config.onDestroy === 'function') {
- this.editor.on('destroy', extension.config.onDestroy.bind(context))
- }
+ // if (typeof extension.config.onDestroy === 'function') {
+ // this.editor.on('destroy', extension.config.onDestroy.bind(context))
+ // }
})
}
@@ -96,11 +97,11 @@ export default class ExtensionManager {
get commands(): RawCommands {
return this.extensions.reduce((commands, extension) => {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema),
- })
+ }
if (!extension.config.addCommands) {
return commands
@@ -108,7 +109,7 @@ export default class ExtensionManager {
return {
...commands,
- ...extension.config.addCommands.bind(context)(),
+ ...getExtensionField(extension, 'addCommands', context)(),
}
}, {} as RawCommands)
}
@@ -117,22 +118,34 @@ export default class ExtensionManager {
return [...this.extensions]
.reverse()
.map(extension => {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema),
- })
+ }
const plugins: Plugin[] = []
- if (extension.config.addKeyboardShortcuts) {
- const keyMapPlugin = keymap(extension.config.addKeyboardShortcuts.bind(context)())
+ const addKeyboardShortcuts = getExtensionField(
+ extension,
+ 'addKeyboardShortcuts',
+ context,
+ )
+
+ if (addKeyboardShortcuts) {
+ const keyMapPlugin = keymap(addKeyboardShortcuts())
plugins.push(keyMapPlugin)
}
- if (this.editor.options.enableInputRules && extension.config.addInputRules) {
- const inputRules = extension.config.addInputRules.bind(context)()
+ const addInputRules = getExtensionField(
+ extension,
+ 'addInputRules',
+ context,
+ )
+
+ if (this.editor.options.enableInputRules && addInputRules) {
+ const inputRules = addInputRules()
const inputRulePlugins = inputRules.length
? [inputRulesPlugin({ rules: inputRules })]
: []
@@ -140,46 +153,30 @@ export default class ExtensionManager {
plugins.push(...inputRulePlugins)
}
- if (this.editor.options.enablePasteRules && extension.config.addPasteRules) {
- const pasteRulePlugins = extension.config.addPasteRules.bind(context)()
+ const addPasteRules = getExtensionField(
+ extension,
+ 'addPasteRules',
+ context,
+ )
+
+ if (this.editor.options.enablePasteRules && addPasteRules) {
+ const pasteRulePlugins = addPasteRules()
plugins.push(...pasteRulePlugins)
}
- // console.log('has pm', extension.config.addProseMirrorPlugins, extension)
+ const addProseMirrorPlugins = getExtensionField(
+ extension,
+ 'addProseMirrorPlugins',
+ context,
+ )
- const getItem = (rootext: any, ext: any, field: string): any => {
- const realctx = createExtensionContext(ext, {
- options: rootext.options,
- // options: getItem(ext, 'defaultOptions'),
- editor: this.editor,
- type: getSchemaTypeByName(ext.config.name, this.schema),
- })
+ if (addProseMirrorPlugins) {
+ const proseMirrorPlugins = addProseMirrorPlugins()
- if (ext.config[field]) {
- if (typeof ext.config[field] === 'function') {
- return ext.config[field].bind(realctx)()
- }
-
- return ext.config[field]
- }
-
- if (ext.parent) {
- return getItem(rootext, ext.parent, field)
- }
-
- return undefined
+ plugins.push(...proseMirrorPlugins)
}
- // console.log('get PM', getItem(extension, 'addProseMirrorPlugins', context))
- const realPMP = getItem(extension, extension, 'addProseMirrorPlugins')
-
- // if (extension.config.addProseMirrorPlugins) {
- // const proseMirrorPlugins = extension.config.addProseMirrorPlugins.bind(context)()
-
- // plugins.push(...proseMirrorPlugins)
- // }
-
return plugins
})
.flat()
@@ -194,15 +191,24 @@ export default class ExtensionManager {
const { nodeExtensions } = splitExtensions(this.extensions)
return Object.fromEntries(nodeExtensions
- .filter(extension => !!extension.config.addNodeView)
+ .filter(extension => !!getExtensionField(extension, 'addNodeView'))
.map(extension => {
- const extensionAttributes = this.attributes.filter(attribute => attribute.type === extension.config.name)
- const context = createExtensionContext(extension, {
+ const name = getExtensionField(extension, 'name')
+ const extensionAttributes = this.attributes.filter(attribute => attribute.type === name)
+ const context = {
options: extension.options,
editor,
type: getNodeType(extension.config.name, this.schema),
- })
- const renderer = extension.config.addNodeView?.call(context) as NodeViewRenderer
+ }
+ const addNodeView = getExtensionField(
+ extension,
+ 'addNodeView',
+ context,
+ )
+
+ if (!addNodeView) {
+ return []
+ }
const nodeview = (
node: ProsemirrorNode,
@@ -212,7 +218,7 @@ export default class ExtensionManager {
) => {
const HTMLAttributes = getRenderedAttributes(node, extensionAttributes)
- return renderer({
+ return addNodeView()({
editor,
node,
getPos,
@@ -231,15 +237,21 @@ export default class ExtensionManager {
const { nodeExtensions } = splitExtensions(this.extensions)
return Object.fromEntries(nodeExtensions
- .filter(extension => !!extension.config.renderText)
+ .filter(extension => !!getExtensionField(extension, 'renderText'))
.map(extension => {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
editor,
type: getNodeType(extension.config.name, this.schema),
- })
+ }
- const textSerializer = (props: { node: ProsemirrorNode }) => extension.config.renderText?.call(context, props)
+ const renderText = getExtensionField(extension, 'renderText', context)
+
+ if (!renderText) {
+ return []
+ }
+
+ const textSerializer = (props: { node: ProsemirrorNode }) => renderText(props)
return [extension.config.name, textSerializer]
}))
diff --git a/packages/core/src/helpers/createExtensionContext.ts b/packages/core/src/helpers/createExtensionContext.ts
deleted file mode 100644
index 675c0d52..00000000
--- a/packages/core/src/helpers/createExtensionContext.ts
+++ /dev/null
@@ -1,101 +0,0 @@
-import { AnyExtension, AnyObject } from '../types'
-
-// export default function createExtensionContext(
-// extension: AnyExtension,
-// data: T,
-// ): T & { parentConfig: AnyObject } {
-// const context: any = {
-// ...data,
-// // get parentConfig() {
-// // return Object.fromEntries(Object.entries(extension.parentConfig).map(([key, value]) => {
-// // if (typeof value !== 'function') {
-// // return [key, value]
-// // }
-
-// // console.log('call', key)
-
-// // return [key, value.bind(context)]
-// // }))
-// // },
-
-// parentConfig: Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
-// if (typeof value !== 'function') {
-// return [key, value]
-// }
-
-// // console.log('call', key)
-
-// return [key, value.bind(data)]
-// })),
-
-// // get parentConfig() {
-// // console.log('parent', extension.parent)
-// // console.log('parent parent', extension.parent?.parent)
-
-// // return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
-// // if (typeof value !== 'function') {
-// // return [key, value]
-// // }
-
-// // // console.log('call', key)
-
-// // return [key, value.bind(context)]
-// // }))
-// // },
-
-// // parentConfig: null,
-// }
-
-// return context
-// }
-
-// export default function createExtensionContext(
-// extension: AnyExtension,
-// data: T,
-// // @ts-ignore
-// ): T & { parentConfig: AnyObject } {
-// const context: any = data
-
-// if (!extension.parent) {
-// context.parentConfig = {}
-
-// return context
-// }
-
-// // const bla = {
-// // ...(extension.parent.parent ? extension.parent.parent.config : {}),
-// // ...extension.parent.config,
-// // }
-
-// context.parentConfig = Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
-// if (typeof value !== 'function') {
-// return [key, value]
-// }
-
-// // console.log('call', key)
-
-// return [key, value.bind(createExtensionContext(extension.parent, data))]
-// }))
-
-// return context
-// }
-
-export default function createExtensionContext(
- extension: AnyExtension,
- data: T,
-): T & { parentConfig: AnyObject } {
- const context: any = {
- ...data,
- get parentConfig() {
- return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
- if (typeof value !== 'function') {
- return [key, value]
- }
-
- return [key, value.bind(context)]
- }))
- },
- }
-
- return context
-}
diff --git a/packages/core/src/helpers/getAttributesFromExtensions.ts b/packages/core/src/helpers/getAttributesFromExtensions.ts
index 467a6683..c70923e3 100644
--- a/packages/core/src/helpers/getAttributesFromExtensions.ts
+++ b/packages/core/src/helpers/getAttributesFromExtensions.ts
@@ -1,12 +1,14 @@
-import createExtensionContext from './createExtensionContext'
import splitExtensions from './splitExtensions'
+import getExtensionField from './getExtensionField'
import {
Extensions,
GlobalAttributes,
Attributes,
Attribute,
ExtensionAttribute,
+ AnyConfig,
} from '../types'
+import { NodeConfig, MarkConfig } from '..'
/**
* Get a list of all extension attributes defined in `addAttribute` and `addGlobalAttribute`.
@@ -25,15 +27,22 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext
}
extensions.forEach(extension => {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
- })
+ }
- if (!extension.config.addGlobalAttributes) {
+ const addGlobalAttributes = getExtensionField(
+ extension,
+ 'addGlobalAttributes',
+ context,
+ )
+
+ if (!addGlobalAttributes) {
return
}
- const globalAttributes = extension.config.addGlobalAttributes.bind(context)() as GlobalAttributes
+ // TODO: remove `as GlobalAttributes`
+ const globalAttributes = addGlobalAttributes() as GlobalAttributes
globalAttributes.forEach(globalAttribute => {
globalAttribute.types.forEach(type => {
@@ -54,21 +63,28 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext
})
nodeAndMarkExtensions.forEach(extension => {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
- })
+ }
- if (!extension.config.addAttributes) {
+ const addAttributes = getExtensionField(
+ extension,
+ 'addAttributes',
+ context,
+ )
+
+ if (!addAttributes) {
return
}
- const attributes = extension.config.addAttributes.bind(context)() as Attributes
+ // TODO: remove `as Attributes`
+ const attributes = addAttributes() as Attributes
Object
.entries(attributes)
.forEach(([name, attribute]) => {
extensionAttributes.push({
- type: extension.config.name,
+ type: getExtensionField(extension, 'name'),
name,
attribute: {
...defaultAttribute,
diff --git a/packages/core/src/helpers/getExtensionField.ts b/packages/core/src/helpers/getExtensionField.ts
new file mode 100644
index 00000000..53065884
--- /dev/null
+++ b/packages/core/src/helpers/getExtensionField.ts
@@ -0,0 +1,25 @@
+import { AnyExtension, AnyObject, RemoveThis } from '../types'
+
+export default function getExtensionField(
+ extension: AnyExtension,
+ field: string,
+ context: AnyObject = {},
+): RemoveThis {
+
+ if (extension.config[field] === undefined && extension.parent) {
+ return getExtensionField(extension.parent, field, context)
+ }
+
+ if (typeof extension.config[field] === 'function') {
+ const value = extension.config[field].bind({
+ ...context,
+ parent: extension.parent
+ ? getExtensionField(extension.parent, field, context)
+ : null,
+ })
+
+ return value
+ }
+
+ return extension.config[field]
+}
diff --git a/packages/core/src/helpers/getSchema.ts b/packages/core/src/helpers/getSchema.ts
index 4ad1ddde..05cc0074 100644
--- a/packages/core/src/helpers/getSchema.ts
+++ b/packages/core/src/helpers/getSchema.ts
@@ -1,13 +1,13 @@
import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model'
import { Extensions } from '../types'
import { ExtensionConfig, NodeConfig, MarkConfig } from '..'
-import createExtensionContext from './createExtensionContext'
import splitExtensions from './splitExtensions'
import getAttributesFromExtensions from './getAttributesFromExtensions'
import getRenderedAttributes from './getRenderedAttributes'
import isEmptyObject from '../utilities/isEmptyObject'
import injectExtensionAttributesToParseRule from './injectExtensionAttributesToParseRule'
import callOrReturn from '../utilities/callOrReturn'
+import getExtensionField from './getExtensionField'
function cleanUpSchemaItem(data: T) {
return Object.fromEntries(Object.entries(data).filter(([key, value]) => {
@@ -46,9 +46,9 @@ export default function getSchema(extensions: Extensions): Schema {
const nodes = Object.fromEntries(nodeExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
- })
+ }
const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => {
const extraFields = callOrReturn(nodeSchemaExtender, context, extension)
@@ -61,29 +61,32 @@ export default function getSchema(extensions: Extensions): Schema {
const schema: NodeSpec = cleanUpSchemaItem({
...extraNodeFields,
- content: callOrReturn(extension.config.content, context),
- marks: callOrReturn(extension.config.marks, context),
- group: callOrReturn(extension.config.group, context),
- inline: callOrReturn(extension.config.inline, context),
- atom: callOrReturn(extension.config.atom, context),
- selectable: callOrReturn(extension.config.selectable, context),
- draggable: callOrReturn(extension.config.draggable, context),
- code: callOrReturn(extension.config.code, context),
- defining: callOrReturn(extension.config.defining, context),
- isolating: callOrReturn(extension.config.isolating, context),
+ content: callOrReturn(getExtensionField(extension, 'content', context)),
+ marks: callOrReturn(getExtensionField(extension, 'marks', context)),
+ group: callOrReturn(getExtensionField(extension, 'group', context)),
+ inline: callOrReturn(getExtensionField(extension, 'inline', context)),
+ atom: callOrReturn(getExtensionField(extension, 'atom', context)),
+ selectable: callOrReturn(getExtensionField(extension, 'selectable', context)),
+ draggable: callOrReturn(getExtensionField(extension, 'draggable', context)),
+ code: callOrReturn(getExtensionField(extension, 'code', context)),
+ defining: callOrReturn(getExtensionField(extension, 'defining', context)),
+ isolating: callOrReturn(getExtensionField(extension, 'isolating', context)),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})),
})
- if (extension.config.parseHTML) {
- schema.parseDOM = extension.config.parseHTML
- .bind(context)()
- ?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
+ const parseHTML = callOrReturn(getExtensionField(extension, 'parseHTML', context))
+
+ if (parseHTML) {
+ schema.parseDOM = parseHTML
+ .map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
}
- if (extension.config.renderHTML) {
- schema.toDOM = node => (extension.config.renderHTML as Function)?.bind(context)({
+ const renderHTML = getExtensionField(extension, 'renderHTML', context)
+
+ if (renderHTML) {
+ schema.toDOM = node => renderHTML({
node,
HTMLAttributes: getRenderedAttributes(node, extensionAttributes),
})
@@ -94,9 +97,9 @@ export default function getSchema(extensions: Extensions): Schema {
const marks = Object.fromEntries(markExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
- })
+ }
const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => {
const extraFields = callOrReturn(markSchemaExtender, context, extension)
@@ -109,23 +112,26 @@ export default function getSchema(extensions: Extensions): Schema {
const schema: MarkSpec = cleanUpSchemaItem({
...extraMarkFields,
- inclusive: callOrReturn(extension.config.inclusive, context),
- excludes: callOrReturn(extension.config.excludes, context),
- group: callOrReturn(extension.config.group, context),
- spanning: callOrReturn(extension.config.spanning, context),
+ inclusive: callOrReturn(getExtensionField(extension, 'inclusive', context)),
+ excludes: callOrReturn(getExtensionField(extension, 'excludes', context)),
+ group: callOrReturn(getExtensionField(extension, 'group', context)),
+ spanning: callOrReturn(getExtensionField(extension, 'spanning', context)),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})),
})
- if (extension.config.parseHTML) {
- schema.parseDOM = extension.config.parseHTML
- .bind(context)()
- ?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
+ const parseHTML = callOrReturn(getExtensionField(extension, 'parseHTML', context))
+
+ if (parseHTML) {
+ schema.parseDOM = parseHTML
+ .map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
}
- if (extension.config.renderHTML) {
- schema.toDOM = mark => (extension.config.renderHTML as Function)?.bind(context)({
+ const renderHTML = getExtensionField(extension, 'renderHTML', context)
+
+ if (renderHTML) {
+ schema.toDOM = mark => renderHTML({
mark,
HTMLAttributes: getRenderedAttributes(mark, extensionAttributes),
})
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index bb4309b6..f5492104 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -12,7 +12,7 @@ export { default as markPasteRule } from './pasteRules/markPasteRule'
export { default as callOrReturn } from './utilities/callOrReturn'
export { default as mergeAttributes } from './utilities/mergeAttributes'
-export { default as createExtensionContext } from './helpers/createExtensionContext'
+export { default as getExtensionField } from './helpers/getExtensionField'
export { default as findChildren } from './helpers/findChildren'
export { default as findParentNode } from './helpers/findParentNode'
export { default as findParentNodeClosestToPos } from './helpers/findParentNodeClosestToPos'
diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts
index db115d30..210e6434 100644
--- a/packages/core/src/types.ts
+++ b/packages/core/src/types.ts
@@ -14,17 +14,31 @@ import { Extension } from './Extension'
import { Node } from './Node'
import { Mark } from './Mark'
import { Editor } from './Editor'
-import { Commands } from '.'
+import {
+ Commands,
+ ExtensionConfig,
+ NodeConfig,
+ MarkConfig,
+} from '.'
+export type AnyConfig = ExtensionConfig | NodeConfig | MarkConfig
export type AnyExtension = Extension | Node | Mark
export type Extensions = AnyExtension[]
export type ParentConfig = Partial<{
- [P in keyof T]: Required[P] extends () => any
+ [P in keyof T]: Required[P] extends (...args: any) => any
? (...args: Parameters[P]>) => ReturnType[P]>
: T[P]
}>
+export type RemoveThis = T extends (...args: any) => any
+ ? (...args: Parameters) => ReturnType
+ : T
+
+export type MaybeReturnType = T extends (...args: any) => any
+ ? ReturnType
+ : T
+
export interface EditorOptions {
element: Element,
content: Content,
diff --git a/packages/core/src/utilities/callOrReturn.ts b/packages/core/src/utilities/callOrReturn.ts
index abbefd9c..a87ee60f 100644
--- a/packages/core/src/utilities/callOrReturn.ts
+++ b/packages/core/src/utilities/callOrReturn.ts
@@ -1,3 +1,5 @@
+import { MaybeReturnType } from '../types'
+
/**
* Optionally calls `value` as a function.
* Otherwise it is returned directly.
@@ -5,7 +7,7 @@
* @param context Optional context to bind to function.
* @param props Optional props to pass to function.
*/
-export default function callOrReturn(value: any, context: any = undefined, ...props: any[]): any {
+export default function callOrReturn(value: T, context: any = undefined, ...props: any[]): MaybeReturnType {
if (typeof value === 'function') {
if (context) {
return value.bind(context)(...props)
@@ -14,5 +16,5 @@ export default function callOrReturn(value: any, context: any = undefined, ...pr
return value(...props)
}
- return value
+ return value as MaybeReturnType
}
diff --git a/packages/extension-gapcursor/src/gapcursor.ts b/packages/extension-gapcursor/src/gapcursor.ts
index 6a32b73a..70ad0048 100644
--- a/packages/extension-gapcursor/src/gapcursor.ts
+++ b/packages/extension-gapcursor/src/gapcursor.ts
@@ -1,7 +1,7 @@
import {
Extension,
callOrReturn,
- createExtensionContext,
+ getExtensionField,
ParentConfig,
} from '@tiptap/core'
import { gapCursor } from 'prosemirror-gapcursor'
@@ -31,12 +31,12 @@ export const Gapcursor = Extension.create({
},
extendNodeSchema(extension) {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
- })
+ }
return {
- allowGapCursor: callOrReturn(extension.config.allowGapCursor, context) ?? null,
+ allowGapCursor: callOrReturn(getExtensionField(extension, 'allowGapCursor', context)) ?? null,
}
},
})
diff --git a/packages/extension-table/src/table.ts b/packages/extension-table/src/table.ts
index d3020359..6e3a12cd 100644
--- a/packages/extension-table/src/table.ts
+++ b/packages/extension-table/src/table.ts
@@ -3,9 +3,9 @@ import {
Command,
ParentConfig,
mergeAttributes,
+ getExtensionField,
findParentNodeClosestToPos,
callOrReturn,
- createExtensionContext,
} from '@tiptap/core'
import {
tableEditing,
@@ -264,12 +264,12 @@ export const Table = Node.create({
},
extendNodeSchema(extension) {
- const context = createExtensionContext(extension, {
+ const context = {
options: extension.options,
- })
+ }
return {
- tableRole: callOrReturn(extension.config.tableRole, context),
+ tableRole: callOrReturn(getExtensionField(extension, 'tableRole', context)),
}
},
})