wip: add getExtensionField

This commit is contained in:
Philipp Kühn
2021-04-15 21:14:33 +02:00
parent d194b90a61
commit 07bc40ce75
11 changed files with 213 additions and 274 deletions

View File

@@ -41,48 +41,13 @@ export default {
Document, Document,
Paragraph, Paragraph,
Text, Text,
CodeBlockLowlight
Extension
.create({
defaultOptions: {
foo: 'foo0',
},
addProseMirrorPlugins() {
console.log(0, this.options)
return []
},
})
.extend({ .extend({
// defaultOptions: { addNodeView() {
// foo: 'foo1', return VueNodeViewRenderer(CodeBlockComponent)
// },
addProseMirrorPlugins() {
console.log(1, this.options)
// console.log(1, this.parentConfig.addProseMirrorPlugins)
return this.parentConfig.addProseMirrorPlugins?.() || []
// return []
}, },
}) })
.extend({ .configure({ lowlight }),
// defaultOptions: {
// foo: 'foo2',
// },
})
// .extend({
// addProseMirrorPlugins() {
// console.log(2, this.parentConfig.addProseMirrorPlugins)
// // return this.parentConfig.addProseMirrorPlugins?.() || []
// return []
// },
// })
// CodeBlockLowlight
// .extend({
// addNodeView() {
// return VueNodeViewRenderer(CodeBlockComponent)
// },
// })
// // .configure({ lowlight }),
], ],
content: ` content: `
<p> <p>

View File

@@ -4,8 +4,8 @@ import { inputRules as inputRulesPlugin } from 'prosemirror-inputrules'
import { EditorView, Decoration } from 'prosemirror-view' import { EditorView, Decoration } from 'prosemirror-view'
import { Plugin } from 'prosemirror-state' import { Plugin } from 'prosemirror-state'
import { Editor } from './Editor' import { Editor } from './Editor'
import { Extensions, NodeViewRenderer, RawCommands } from './types' import { Extensions, RawCommands, AnyConfig } from './types'
import createExtensionContext from './helpers/createExtensionContext' import getExtensionField from './helpers/getExtensionField'
import getSchema from './helpers/getSchema' import getSchema from './helpers/getSchema'
import getSchemaTypeByName from './helpers/getSchemaTypeByName' import getSchemaTypeByName from './helpers/getSchemaTypeByName'
import getNodeType from './helpers/getNodeType' import getNodeType from './helpers/getNodeType'
@@ -13,6 +13,7 @@ import splitExtensions from './helpers/splitExtensions'
import getAttributesFromExtensions from './helpers/getAttributesFromExtensions' import getAttributesFromExtensions from './helpers/getAttributesFromExtensions'
import getRenderedAttributes from './helpers/getRenderedAttributes' import getRenderedAttributes from './helpers/getRenderedAttributes'
import callOrReturn from './utilities/callOrReturn' import callOrReturn from './utilities/callOrReturn'
import { NodeConfig } from '.'
export default class ExtensionManager { export default class ExtensionManager {
@@ -30,51 +31,51 @@ export default class ExtensionManager {
this.schema = getSchema(this.extensions) this.schema = getSchema(this.extensions)
this.extensions.forEach(extension => { this.extensions.forEach(extension => {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
editor: this.editor, editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema), type: getSchemaTypeByName(extension.config.name, this.schema),
}) }
if (extension.type === 'mark') { if (extension.type === 'mark') {
const keepOnSplit = callOrReturn(extension.config.keepOnSplit, context) ?? true const keepOnSplit = callOrReturn(getExtensionField(extension, 'keepOnSplit', context)) ?? true
if (keepOnSplit) { if (keepOnSplit) {
this.splittableMarks.push(extension.config.name) this.splittableMarks.push(extension.config.name)
} }
} }
if (typeof extension.config.onBeforeCreate === 'function') { // if (typeof extension.config.onBeforeCreate === 'function') {
this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context)) // this.editor.on('beforeCreate', extension.config.onBeforeCreate.bind(context))
} // }
if (typeof extension.config.onCreate === 'function') { // if (typeof extension.config.onCreate === 'function') {
this.editor.on('create', extension.config.onCreate.bind(context)) // this.editor.on('create', extension.config.onCreate.bind(context))
} // }
if (typeof extension.config.onUpdate === 'function') { // if (typeof extension.config.onUpdate === 'function') {
this.editor.on('update', extension.config.onUpdate.bind(context)) // this.editor.on('update', extension.config.onUpdate.bind(context))
} // }
if (typeof extension.config.onSelectionUpdate === 'function') { // if (typeof extension.config.onSelectionUpdate === 'function') {
this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context)) // this.editor.on('selectionUpdate', extension.config.onSelectionUpdate.bind(context))
} // }
if (typeof extension.config.onTransaction === 'function') { // if (typeof extension.config.onTransaction === 'function') {
this.editor.on('transaction', extension.config.onTransaction.bind(context)) // this.editor.on('transaction', extension.config.onTransaction.bind(context))
} // }
if (typeof extension.config.onFocus === 'function') { // if (typeof extension.config.onFocus === 'function') {
this.editor.on('focus', extension.config.onFocus.bind(context)) // this.editor.on('focus', extension.config.onFocus.bind(context))
} // }
if (typeof extension.config.onBlur === 'function') { // if (typeof extension.config.onBlur === 'function') {
this.editor.on('blur', extension.config.onBlur.bind(context)) // this.editor.on('blur', extension.config.onBlur.bind(context))
} // }
if (typeof extension.config.onDestroy === 'function') { // if (typeof extension.config.onDestroy === 'function') {
this.editor.on('destroy', extension.config.onDestroy.bind(context)) // this.editor.on('destroy', extension.config.onDestroy.bind(context))
} // }
}) })
} }
@@ -96,11 +97,11 @@ export default class ExtensionManager {
get commands(): RawCommands { get commands(): RawCommands {
return this.extensions.reduce((commands, extension) => { return this.extensions.reduce((commands, extension) => {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
editor: this.editor, editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema), type: getSchemaTypeByName(extension.config.name, this.schema),
}) }
if (!extension.config.addCommands) { if (!extension.config.addCommands) {
return commands return commands
@@ -108,7 +109,7 @@ export default class ExtensionManager {
return { return {
...commands, ...commands,
...extension.config.addCommands.bind(context)(), ...getExtensionField(extension, 'addCommands', context)(),
} }
}, {} as RawCommands) }, {} as RawCommands)
} }
@@ -117,22 +118,34 @@ export default class ExtensionManager {
return [...this.extensions] return [...this.extensions]
.reverse() .reverse()
.map(extension => { .map(extension => {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
editor: this.editor, editor: this.editor,
type: getSchemaTypeByName(extension.config.name, this.schema), type: getSchemaTypeByName(extension.config.name, this.schema),
}) }
const plugins: Plugin[] = [] const plugins: Plugin[] = []
if (extension.config.addKeyboardShortcuts) { const addKeyboardShortcuts = getExtensionField<AnyConfig['addKeyboardShortcuts']>(
const keyMapPlugin = keymap(extension.config.addKeyboardShortcuts.bind(context)()) extension,
'addKeyboardShortcuts',
context,
)
if (addKeyboardShortcuts) {
const keyMapPlugin = keymap(addKeyboardShortcuts())
plugins.push(keyMapPlugin) plugins.push(keyMapPlugin)
} }
if (this.editor.options.enableInputRules && extension.config.addInputRules) { const addInputRules = getExtensionField<AnyConfig['addInputRules']>(
const inputRules = extension.config.addInputRules.bind(context)() extension,
'addInputRules',
context,
)
if (this.editor.options.enableInputRules && addInputRules) {
const inputRules = addInputRules()
const inputRulePlugins = inputRules.length const inputRulePlugins = inputRules.length
? [inputRulesPlugin({ rules: inputRules })] ? [inputRulesPlugin({ rules: inputRules })]
: [] : []
@@ -140,46 +153,30 @@ export default class ExtensionManager {
plugins.push(...inputRulePlugins) plugins.push(...inputRulePlugins)
} }
if (this.editor.options.enablePasteRules && extension.config.addPasteRules) { const addPasteRules = getExtensionField<AnyConfig['addPasteRules']>(
const pasteRulePlugins = extension.config.addPasteRules.bind(context)() extension,
'addPasteRules',
context,
)
if (this.editor.options.enablePasteRules && addPasteRules) {
const pasteRulePlugins = addPasteRules()
plugins.push(...pasteRulePlugins) plugins.push(...pasteRulePlugins)
} }
// console.log('has pm', extension.config.addProseMirrorPlugins, extension) const addProseMirrorPlugins = getExtensionField<AnyConfig['addProseMirrorPlugins']>(
extension,
'addProseMirrorPlugins',
context,
)
const getItem = (rootext: any, ext: any, field: string): any => { if (addProseMirrorPlugins) {
const realctx = createExtensionContext(ext, { const proseMirrorPlugins = addProseMirrorPlugins()
options: rootext.options,
// options: getItem(ext, 'defaultOptions'),
editor: this.editor,
type: getSchemaTypeByName(ext.config.name, this.schema),
})
if (ext.config[field]) { plugins.push(...proseMirrorPlugins)
if (typeof ext.config[field] === 'function') {
return ext.config[field].bind(realctx)()
}
return ext.config[field]
}
if (ext.parent) {
return getItem(rootext, ext.parent, field)
}
return undefined
} }
// console.log('get PM', getItem(extension, 'addProseMirrorPlugins', context))
const realPMP = getItem(extension, extension, 'addProseMirrorPlugins')
// if (extension.config.addProseMirrorPlugins) {
// const proseMirrorPlugins = extension.config.addProseMirrorPlugins.bind(context)()
// plugins.push(...proseMirrorPlugins)
// }
return plugins return plugins
}) })
.flat() .flat()
@@ -194,15 +191,24 @@ export default class ExtensionManager {
const { nodeExtensions } = splitExtensions(this.extensions) const { nodeExtensions } = splitExtensions(this.extensions)
return Object.fromEntries(nodeExtensions return Object.fromEntries(nodeExtensions
.filter(extension => !!extension.config.addNodeView) .filter(extension => !!getExtensionField(extension, 'addNodeView'))
.map(extension => { .map(extension => {
const extensionAttributes = this.attributes.filter(attribute => attribute.type === extension.config.name) const name = getExtensionField<NodeConfig['name']>(extension, 'name')
const context = createExtensionContext(extension, { const extensionAttributes = this.attributes.filter(attribute => attribute.type === name)
const context = {
options: extension.options, options: extension.options,
editor, editor,
type: getNodeType(extension.config.name, this.schema), type: getNodeType(extension.config.name, this.schema),
}) }
const renderer = extension.config.addNodeView?.call(context) as NodeViewRenderer const addNodeView = getExtensionField<NodeConfig['addNodeView']>(
extension,
'addNodeView',
context,
)
if (!addNodeView) {
return []
}
const nodeview = ( const nodeview = (
node: ProsemirrorNode, node: ProsemirrorNode,
@@ -212,7 +218,7 @@ export default class ExtensionManager {
) => { ) => {
const HTMLAttributes = getRenderedAttributes(node, extensionAttributes) const HTMLAttributes = getRenderedAttributes(node, extensionAttributes)
return renderer({ return addNodeView()({
editor, editor,
node, node,
getPos, getPos,
@@ -231,15 +237,21 @@ export default class ExtensionManager {
const { nodeExtensions } = splitExtensions(this.extensions) const { nodeExtensions } = splitExtensions(this.extensions)
return Object.fromEntries(nodeExtensions return Object.fromEntries(nodeExtensions
.filter(extension => !!extension.config.renderText) .filter(extension => !!getExtensionField(extension, 'renderText'))
.map(extension => { .map(extension => {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
editor, editor,
type: getNodeType(extension.config.name, this.schema), type: getNodeType(extension.config.name, this.schema),
}) }
const textSerializer = (props: { node: ProsemirrorNode }) => extension.config.renderText?.call(context, props) const renderText = getExtensionField<NodeConfig['renderText']>(extension, 'renderText', context)
if (!renderText) {
return []
}
const textSerializer = (props: { node: ProsemirrorNode }) => renderText(props)
return [extension.config.name, textSerializer] return [extension.config.name, textSerializer]
})) }))

View File

@@ -1,101 +0,0 @@
import { AnyExtension, AnyObject } from '../types'
// export default function createExtensionContext<T>(
// extension: AnyExtension,
// data: T,
// ): T & { parentConfig: AnyObject } {
// const context: any = {
// ...data,
// // get parentConfig() {
// // return Object.fromEntries(Object.entries(extension.parentConfig).map(([key, value]) => {
// // if (typeof value !== 'function') {
// // return [key, value]
// // }
// // console.log('call', key)
// // return [key, value.bind(context)]
// // }))
// // },
// parentConfig: Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
// if (typeof value !== 'function') {
// return [key, value]
// }
// // console.log('call', key)
// return [key, value.bind(data)]
// })),
// // get parentConfig() {
// // console.log('parent', extension.parent)
// // console.log('parent parent', extension.parent?.parent)
// // return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
// // if (typeof value !== 'function') {
// // return [key, value]
// // }
// // // console.log('call', key)
// // return [key, value.bind(context)]
// // }))
// // },
// // parentConfig: null,
// }
// return context
// }
// export default function createExtensionContext<T>(
// extension: AnyExtension,
// data: T,
// // @ts-ignore
// ): T & { parentConfig: AnyObject } {
// const context: any = data
// if (!extension.parent) {
// context.parentConfig = {}
// return context
// }
// // const bla = {
// // ...(extension.parent.parent ? extension.parent.parent.config : {}),
// // ...extension.parent.config,
// // }
// context.parentConfig = Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
// if (typeof value !== 'function') {
// return [key, value]
// }
// // console.log('call', key)
// return [key, value.bind(createExtensionContext(extension.parent, data))]
// }))
// return context
// }
export default function createExtensionContext<T>(
extension: AnyExtension,
data: T,
): T & { parentConfig: AnyObject } {
const context: any = {
...data,
get parentConfig() {
return Object.fromEntries(Object.entries(extension.parent.config).map(([key, value]) => {
if (typeof value !== 'function') {
return [key, value]
}
return [key, value.bind(context)]
}))
},
}
return context
}

View File

@@ -1,12 +1,14 @@
import createExtensionContext from './createExtensionContext'
import splitExtensions from './splitExtensions' import splitExtensions from './splitExtensions'
import getExtensionField from './getExtensionField'
import { import {
Extensions, Extensions,
GlobalAttributes, GlobalAttributes,
Attributes, Attributes,
Attribute, Attribute,
ExtensionAttribute, ExtensionAttribute,
AnyConfig,
} from '../types' } from '../types'
import { NodeConfig, MarkConfig } from '..'
/** /**
* Get a list of all extension attributes defined in `addAttribute` and `addGlobalAttribute`. * Get a list of all extension attributes defined in `addAttribute` and `addGlobalAttribute`.
@@ -25,15 +27,22 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext
} }
extensions.forEach(extension => { extensions.forEach(extension => {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
}) }
if (!extension.config.addGlobalAttributes) { const addGlobalAttributes = getExtensionField<AnyConfig['addGlobalAttributes']>(
extension,
'addGlobalAttributes',
context,
)
if (!addGlobalAttributes) {
return return
} }
const globalAttributes = extension.config.addGlobalAttributes.bind(context)() as GlobalAttributes // TODO: remove `as GlobalAttributes`
const globalAttributes = addGlobalAttributes() as GlobalAttributes
globalAttributes.forEach(globalAttribute => { globalAttributes.forEach(globalAttribute => {
globalAttribute.types.forEach(type => { globalAttribute.types.forEach(type => {
@@ -54,21 +63,28 @@ export default function getAttributesFromExtensions(extensions: Extensions): Ext
}) })
nodeAndMarkExtensions.forEach(extension => { nodeAndMarkExtensions.forEach(extension => {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
}) }
if (!extension.config.addAttributes) { const addAttributes = getExtensionField<NodeConfig['addAttributes'] | MarkConfig['addAttributes']>(
extension,
'addAttributes',
context,
)
if (!addAttributes) {
return return
} }
const attributes = extension.config.addAttributes.bind(context)() as Attributes // TODO: remove `as Attributes`
const attributes = addAttributes() as Attributes
Object Object
.entries(attributes) .entries(attributes)
.forEach(([name, attribute]) => { .forEach(([name, attribute]) => {
extensionAttributes.push({ extensionAttributes.push({
type: extension.config.name, type: getExtensionField(extension, 'name'),
name, name,
attribute: { attribute: {
...defaultAttribute, ...defaultAttribute,

View File

@@ -0,0 +1,25 @@
import { AnyExtension, AnyObject, RemoveThis } from '../types'
export default function getExtensionField<T = any>(
extension: AnyExtension,
field: string,
context: AnyObject = {},
): RemoveThis<T> {
if (extension.config[field] === undefined && extension.parent) {
return getExtensionField(extension.parent, field, context)
}
if (typeof extension.config[field] === 'function') {
const value = extension.config[field].bind({
...context,
parent: extension.parent
? getExtensionField(extension.parent, field, context)
: null,
})
return value
}
return extension.config[field]
}

View File

@@ -1,13 +1,13 @@
import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model' import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model'
import { Extensions } from '../types' import { Extensions } from '../types'
import { ExtensionConfig, NodeConfig, MarkConfig } from '..' import { ExtensionConfig, NodeConfig, MarkConfig } from '..'
import createExtensionContext from './createExtensionContext'
import splitExtensions from './splitExtensions' import splitExtensions from './splitExtensions'
import getAttributesFromExtensions from './getAttributesFromExtensions' import getAttributesFromExtensions from './getAttributesFromExtensions'
import getRenderedAttributes from './getRenderedAttributes' import getRenderedAttributes from './getRenderedAttributes'
import isEmptyObject from '../utilities/isEmptyObject' import isEmptyObject from '../utilities/isEmptyObject'
import injectExtensionAttributesToParseRule from './injectExtensionAttributesToParseRule' import injectExtensionAttributesToParseRule from './injectExtensionAttributesToParseRule'
import callOrReturn from '../utilities/callOrReturn' import callOrReturn from '../utilities/callOrReturn'
import getExtensionField from './getExtensionField'
function cleanUpSchemaItem<T>(data: T) { function cleanUpSchemaItem<T>(data: T) {
return Object.fromEntries(Object.entries(data).filter(([key, value]) => { return Object.fromEntries(Object.entries(data).filter(([key, value]) => {
@@ -46,9 +46,9 @@ export default function getSchema(extensions: Extensions): Schema {
const nodes = Object.fromEntries(nodeExtensions.map(extension => { const nodes = Object.fromEntries(nodeExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
}) }
const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => { const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => {
const extraFields = callOrReturn(nodeSchemaExtender, context, extension) const extraFields = callOrReturn(nodeSchemaExtender, context, extension)
@@ -61,29 +61,32 @@ export default function getSchema(extensions: Extensions): Schema {
const schema: NodeSpec = cleanUpSchemaItem({ const schema: NodeSpec = cleanUpSchemaItem({
...extraNodeFields, ...extraNodeFields,
content: callOrReturn(extension.config.content, context), content: callOrReturn(getExtensionField<NodeConfig['content']>(extension, 'content', context)),
marks: callOrReturn(extension.config.marks, context), marks: callOrReturn(getExtensionField<NodeConfig['marks']>(extension, 'marks', context)),
group: callOrReturn(extension.config.group, context), group: callOrReturn(getExtensionField<NodeConfig['group']>(extension, 'group', context)),
inline: callOrReturn(extension.config.inline, context), inline: callOrReturn(getExtensionField<NodeConfig['inline']>(extension, 'inline', context)),
atom: callOrReturn(extension.config.atom, context), atom: callOrReturn(getExtensionField<NodeConfig['atom']>(extension, 'atom', context)),
selectable: callOrReturn(extension.config.selectable, context), selectable: callOrReturn(getExtensionField<NodeConfig['selectable']>(extension, 'selectable', context)),
draggable: callOrReturn(extension.config.draggable, context), draggable: callOrReturn(getExtensionField<NodeConfig['draggable']>(extension, 'draggable', context)),
code: callOrReturn(extension.config.code, context), code: callOrReturn(getExtensionField<NodeConfig['code']>(extension, 'code', context)),
defining: callOrReturn(extension.config.defining, context), defining: callOrReturn(getExtensionField<NodeConfig['defining']>(extension, 'defining', context)),
isolating: callOrReturn(extension.config.isolating, context), isolating: callOrReturn(getExtensionField<NodeConfig['isolating']>(extension, 'isolating', context)),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})), })),
}) })
if (extension.config.parseHTML) { const parseHTML = callOrReturn(getExtensionField<NodeConfig['parseHTML']>(extension, 'parseHTML', context))
schema.parseDOM = extension.config.parseHTML
.bind(context)() if (parseHTML) {
?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes)) schema.parseDOM = parseHTML
.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
} }
if (extension.config.renderHTML) { const renderHTML = getExtensionField<NodeConfig['renderHTML']>(extension, 'renderHTML', context)
schema.toDOM = node => (extension.config.renderHTML as Function)?.bind(context)({
if (renderHTML) {
schema.toDOM = node => renderHTML({
node, node,
HTMLAttributes: getRenderedAttributes(node, extensionAttributes), HTMLAttributes: getRenderedAttributes(node, extensionAttributes),
}) })
@@ -94,9 +97,9 @@ export default function getSchema(extensions: Extensions): Schema {
const marks = Object.fromEntries(markExtensions.map(extension => { const marks = Object.fromEntries(markExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
}) }
const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => { const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => {
const extraFields = callOrReturn(markSchemaExtender, context, extension) const extraFields = callOrReturn(markSchemaExtender, context, extension)
@@ -109,23 +112,26 @@ export default function getSchema(extensions: Extensions): Schema {
const schema: MarkSpec = cleanUpSchemaItem({ const schema: MarkSpec = cleanUpSchemaItem({
...extraMarkFields, ...extraMarkFields,
inclusive: callOrReturn(extension.config.inclusive, context), inclusive: callOrReturn(getExtensionField<NodeConfig['inclusive']>(extension, 'inclusive', context)),
excludes: callOrReturn(extension.config.excludes, context), excludes: callOrReturn(getExtensionField<NodeConfig['excludes']>(extension, 'excludes', context)),
group: callOrReturn(extension.config.group, context), group: callOrReturn(getExtensionField<NodeConfig['group']>(extension, 'group', context)),
spanning: callOrReturn(extension.config.spanning, context), spanning: callOrReturn(getExtensionField<NodeConfig['spanning']>(extension, 'spanning', context)),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})), })),
}) })
if (extension.config.parseHTML) { const parseHTML = callOrReturn(getExtensionField<MarkConfig['parseHTML']>(extension, 'parseHTML', context))
schema.parseDOM = extension.config.parseHTML
.bind(context)() if (parseHTML) {
?.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes)) schema.parseDOM = parseHTML
.map(parseRule => injectExtensionAttributesToParseRule(parseRule, extensionAttributes))
} }
if (extension.config.renderHTML) { const renderHTML = getExtensionField<MarkConfig['renderHTML']>(extension, 'renderHTML', context)
schema.toDOM = mark => (extension.config.renderHTML as Function)?.bind(context)({
if (renderHTML) {
schema.toDOM = mark => renderHTML({
mark, mark,
HTMLAttributes: getRenderedAttributes(mark, extensionAttributes), HTMLAttributes: getRenderedAttributes(mark, extensionAttributes),
}) })

View File

@@ -12,7 +12,7 @@ export { default as markPasteRule } from './pasteRules/markPasteRule'
export { default as callOrReturn } from './utilities/callOrReturn' export { default as callOrReturn } from './utilities/callOrReturn'
export { default as mergeAttributes } from './utilities/mergeAttributes' export { default as mergeAttributes } from './utilities/mergeAttributes'
export { default as createExtensionContext } from './helpers/createExtensionContext' export { default as getExtensionField } from './helpers/getExtensionField'
export { default as findChildren } from './helpers/findChildren' export { default as findChildren } from './helpers/findChildren'
export { default as findParentNode } from './helpers/findParentNode' export { default as findParentNode } from './helpers/findParentNode'
export { default as findParentNodeClosestToPos } from './helpers/findParentNodeClosestToPos' export { default as findParentNodeClosestToPos } from './helpers/findParentNodeClosestToPos'

View File

@@ -14,17 +14,31 @@ import { Extension } from './Extension'
import { Node } from './Node' import { Node } from './Node'
import { Mark } from './Mark' import { Mark } from './Mark'
import { Editor } from './Editor' import { Editor } from './Editor'
import { Commands } from '.' import {
Commands,
ExtensionConfig,
NodeConfig,
MarkConfig,
} from '.'
export type AnyConfig = ExtensionConfig | NodeConfig | MarkConfig
export type AnyExtension = Extension | Node | Mark export type AnyExtension = Extension | Node | Mark
export type Extensions = AnyExtension[] export type Extensions = AnyExtension[]
export type ParentConfig<T> = Partial<{ export type ParentConfig<T> = Partial<{
[P in keyof T]: Required<T>[P] extends () => any [P in keyof T]: Required<T>[P] extends (...args: any) => any
? (...args: Parameters<Required<T>[P]>) => ReturnType<Required<T>[P]> ? (...args: Parameters<Required<T>[P]>) => ReturnType<Required<T>[P]>
: T[P] : T[P]
}> }>
export type RemoveThis<T> = T extends (...args: any) => any
? (...args: Parameters<T>) => ReturnType<T>
: T
export type MaybeReturnType<T> = T extends (...args: any) => any
? ReturnType<T>
: T
export interface EditorOptions { export interface EditorOptions {
element: Element, element: Element,
content: Content, content: Content,

View File

@@ -1,3 +1,5 @@
import { MaybeReturnType } from '../types'
/** /**
* Optionally calls `value` as a function. * Optionally calls `value` as a function.
* Otherwise it is returned directly. * Otherwise it is returned directly.
@@ -5,7 +7,7 @@
* @param context Optional context to bind to function. * @param context Optional context to bind to function.
* @param props Optional props to pass to function. * @param props Optional props to pass to function.
*/ */
export default function callOrReturn(value: any, context: any = undefined, ...props: any[]): any { export default function callOrReturn<T>(value: T, context: any = undefined, ...props: any[]): MaybeReturnType<T> {
if (typeof value === 'function') { if (typeof value === 'function') {
if (context) { if (context) {
return value.bind(context)(...props) return value.bind(context)(...props)
@@ -14,5 +16,5 @@ export default function callOrReturn(value: any, context: any = undefined, ...pr
return value(...props) return value(...props)
} }
return value return value as MaybeReturnType<T>
} }

View File

@@ -1,7 +1,7 @@
import { import {
Extension, Extension,
callOrReturn, callOrReturn,
createExtensionContext, getExtensionField,
ParentConfig, ParentConfig,
} from '@tiptap/core' } from '@tiptap/core'
import { gapCursor } from 'prosemirror-gapcursor' import { gapCursor } from 'prosemirror-gapcursor'
@@ -31,12 +31,12 @@ export const Gapcursor = Extension.create({
}, },
extendNodeSchema(extension) { extendNodeSchema(extension) {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
}) }
return { return {
allowGapCursor: callOrReturn(extension.config.allowGapCursor, context) ?? null, allowGapCursor: callOrReturn(getExtensionField(extension, 'allowGapCursor', context)) ?? null,
} }
}, },
}) })

View File

@@ -3,9 +3,9 @@ import {
Command, Command,
ParentConfig, ParentConfig,
mergeAttributes, mergeAttributes,
getExtensionField,
findParentNodeClosestToPos, findParentNodeClosestToPos,
callOrReturn, callOrReturn,
createExtensionContext,
} from '@tiptap/core' } from '@tiptap/core'
import { import {
tableEditing, tableEditing,
@@ -264,12 +264,12 @@ export const Table = Node.create<TableOptions>({
}, },
extendNodeSchema(extension) { extendNodeSchema(extension) {
const context = createExtensionContext(extension, { const context = {
options: extension.options, options: extension.options,
}) }
return { return {
tableRole: callOrReturn(extension.config.tableRole, context), tableRole: callOrReturn(getExtensionField(extension, 'tableRole', context)),
} }
}, },
}) })